Unverified Commit 1b011e68 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] move grpc client from router to worker and builder (#10958)

parent 5c0efa56
...@@ -9,7 +9,7 @@ use serde_json; ...@@ -9,7 +9,7 @@ use serde_json;
use std::fmt; use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock}; use std::sync::{Arc, LazyLock};
use tokio::sync::Mutex; use tokio::sync::{Mutex, RwLock};
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
...@@ -337,8 +337,8 @@ pub struct BasicWorker { ...@@ -337,8 +337,8 @@ pub struct BasicWorker {
pub consecutive_failures: Arc<AtomicUsize>, pub consecutive_failures: Arc<AtomicUsize>,
pub consecutive_successes: Arc<AtomicUsize>, pub consecutive_successes: Arc<AtomicUsize>,
pub circuit_breaker: CircuitBreaker, pub circuit_breaker: CircuitBreaker,
/// Optional gRPC client for gRPC workers /// Lazily initialized gRPC client for gRPC workers
pub grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>, pub grpc_client: Arc<RwLock<Option<Arc<Mutex<SglangSchedulerClient>>>>>,
} }
impl fmt::Debug for BasicWorker { impl fmt::Debug for BasicWorker {
...@@ -347,7 +347,7 @@ impl fmt::Debug for BasicWorker { ...@@ -347,7 +347,7 @@ impl fmt::Debug for BasicWorker {
.field("metadata", &self.metadata) .field("metadata", &self.metadata)
.field("healthy", &self.healthy.load(Ordering::Relaxed)) .field("healthy", &self.healthy.load(Ordering::Relaxed))
.field("circuit_breaker", &self.circuit_breaker) .field("circuit_breaker", &self.circuit_breaker)
.field("has_grpc_client", &self.grpc_client.is_some()) .field("grpc_client", &"<RwLock>")
.finish() .finish()
} }
} }
...@@ -421,7 +421,7 @@ impl Worker for BasicWorker { ...@@ -421,7 +421,7 @@ impl Worker for BasicWorker {
} }
} }
ConnectionMode::Grpc { .. } => { ConnectionMode::Grpc { .. } => {
// Use the new get_grpc_client() method // Use the new get_grpc_client() method for lazy initialization
match self.get_grpc_client().await { match self.get_grpc_client().await {
Ok(Some(grpc_client)) => { Ok(Some(grpc_client)) => {
let mut client = grpc_client.lock().await; let mut client = grpc_client.lock().await;
...@@ -532,19 +532,45 @@ impl Worker for BasicWorker { ...@@ -532,19 +532,45 @@ impl Worker for BasicWorker {
match self.metadata.connection_mode { match self.metadata.connection_mode {
ConnectionMode::Http => Ok(None), ConnectionMode::Http => Ok(None),
ConnectionMode::Grpc { .. } => { ConnectionMode::Grpc { .. } => {
// If we already have a client, return it {
if let Some(ref client) = self.grpc_client { let client_guard = self.grpc_client.read().await;
if let Some(ref client) = *client_guard {
return Ok(Some(client.clone()));
}
}
let mut client_guard = self.grpc_client.write().await;
if let Some(ref client) = *client_guard {
return Ok(Some(client.clone())); return Ok(Some(client.clone()));
} }
// For lazy initialization, we would need to change grpc_client to be mutable tracing::info!(
// For now, return error if no client exists (will be initialized during worker creation) "Lazily initializing gRPC client for worker: {}",
Err(WorkerError::ConnectionFailed { self.metadata.url
url: self.metadata.url.clone(), );
reason: match SglangSchedulerClient::connect(&self.metadata.url).await {
"gRPC client not initialized. Client should be set during worker creation" Ok(client) => {
.to_string(), let client_arc = Arc::new(Mutex::new(client));
}) *client_guard = Some(client_arc.clone());
tracing::info!(
"Successfully connected gRPC client for worker: {}",
self.metadata.url
);
Ok(Some(client_arc))
}
Err(e) => {
tracing::error!(
"Failed to connect gRPC client for worker {}: {}",
self.metadata.url,
e
);
Err(WorkerError::ConnectionFailed {
url: self.metadata.url.clone(),
reason: format!("Failed to connect to gRPC server: {}", e),
})
}
}
} }
} }
} }
...@@ -553,12 +579,11 @@ impl Worker for BasicWorker { ...@@ -553,12 +579,11 @@ impl Worker for BasicWorker {
match self.metadata.connection_mode { match self.metadata.connection_mode {
ConnectionMode::Http => Ok(()), ConnectionMode::Http => Ok(()),
ConnectionMode::Grpc { .. } => { ConnectionMode::Grpc { .. } => {
// For now, we can't reset the client since it's not mutable let mut client_guard = self.grpc_client.write().await;
// This would require changing the grpc_client field to use RwLock or OnceCell if client_guard.is_some() {
// which we'll do in a future iteration tracing::info!("Resetting gRPC client for worker: {}", self.metadata.url);
tracing::warn!( *client_guard = None;
"gRPC client reset not yet implemented - requires mutable client storage" }
);
Ok(()) Ok(())
} }
} }
......
...@@ -100,7 +100,7 @@ impl BasicWorkerBuilder { ...@@ -100,7 +100,7 @@ impl BasicWorkerBuilder {
atomic::{AtomicBool, AtomicUsize}, atomic::{AtomicBool, AtomicUsize},
Arc, Arc,
}; };
use tokio::sync::Mutex; use tokio::sync::{Mutex, RwLock};
let metadata = WorkerMetadata { let metadata = WorkerMetadata {
url: self.url.clone(), url: self.url.clone(),
...@@ -111,6 +111,10 @@ impl BasicWorkerBuilder { ...@@ -111,6 +111,10 @@ impl BasicWorkerBuilder {
health_config: self.health_config, health_config: self.health_config,
}; };
let grpc_client = Arc::new(RwLock::new(
self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
));
BasicWorker { BasicWorker {
metadata, metadata,
load_counter: Arc::new(AtomicUsize::new(0)), load_counter: Arc::new(AtomicUsize::new(0)),
...@@ -119,7 +123,7 @@ impl BasicWorkerBuilder { ...@@ -119,7 +123,7 @@ impl BasicWorkerBuilder {
consecutive_failures: Arc::new(AtomicUsize::new(0)), consecutive_failures: Arc::new(AtomicUsize::new(0)),
consecutive_successes: Arc::new(AtomicUsize::new(0)), consecutive_successes: Arc::new(AtomicUsize::new(0)),
circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config), circuit_breaker: CircuitBreaker::with_config(self.circuit_breaker_config),
grpc_client: self.grpc_client.map(|client| Arc::new(Mutex::new(client))), grpc_client,
} }
} }
} }
......
...@@ -4,7 +4,6 @@ use crate::config::types::RetryConfig; ...@@ -4,7 +4,6 @@ use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
}; };
use crate::grpc_client::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
...@@ -18,10 +17,9 @@ use axum::{ ...@@ -18,10 +17,9 @@ use axum::{
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tracing::{info, warn}; use tracing::info;
/// gRPC PD (Prefill-Decode) router implementation for SGLang /// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete #[allow(dead_code)] // Fields will be used once implementation is complete
...@@ -89,86 +87,55 @@ impl GrpcPDRouter { ...@@ -89,86 +87,55 @@ impl GrpcPDRouter {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
}; };
// Create gRPC clients for prefill workers
let mut prefill_grpc_clients = HashMap::new();
for (url, _bootstrap_port) in &prefill_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
prefill_grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC prefill worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
// Continue with other workers
}
}
}
// Create gRPC clients for decode workers
let mut decode_grpc_clients = HashMap::new();
for url in &decode_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
decode_grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC decode worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
// Continue with other workers
}
}
}
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Create Prefill Worker trait objects with gRPC connection mode and register them
for (url, bootstrap_port) in &prefill_urls { for (url, bootstrap_port) in &prefill_urls {
if let Some(client) = prefill_grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone())
let worker = BasicWorkerBuilder::new(url.clone()) .worker_type(WorkerType::Prefill {
.worker_type(WorkerType::Prefill { bootstrap_port: *bootstrap_port,
bootstrap_port: *bootstrap_port, })
}) .connection_mode(crate::core::ConnectionMode::Grpc {
.connection_mode(crate::core::ConnectionMode::Grpc { port: *bootstrap_port,
port: *bootstrap_port, })
}) .circuit_breaker_config(core_cb_config.clone())
.circuit_breaker_config(core_cb_config.clone()) .health_config(HealthConfig {
.health_config(HealthConfig { timeout_secs: ctx.router_config.health_check.timeout_secs,
timeout_secs: ctx.router_config.health_check.timeout_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, endpoint: ctx.router_config.health_check.endpoint.clone(),
endpoint: ctx.router_config.health_check.endpoint.clone(), failure_threshold: ctx.router_config.health_check.failure_threshold,
failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, })
}) // No longer passing pre-initialized client - will be created lazily
.grpc_client(client) .build();
.build();
worker_registry.register(Arc::new(worker));
// Register worker in the centralized registry info!(
worker_registry.register(Arc::new(worker)); "Registered gRPC prefill worker at {} (will connect on first use)",
} url
);
} }
// Create Decode Worker trait objects with gRPC connection mode and register them
for url in &decode_urls { for url in &decode_urls {
if let Some(client) = decode_grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone())
let worker = BasicWorkerBuilder::new(url.clone()) .worker_type(WorkerType::Decode)
.worker_type(WorkerType::Decode) .connection_mode(crate::core::ConnectionMode::Grpc { port: None })
.connection_mode(crate::core::ConnectionMode::Grpc { port: None }) .circuit_breaker_config(core_cb_config.clone())
.circuit_breaker_config(core_cb_config.clone()) .health_config(HealthConfig {
.health_config(HealthConfig { timeout_secs: ctx.router_config.health_check.timeout_secs,
timeout_secs: ctx.router_config.health_check.timeout_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, endpoint: ctx.router_config.health_check.endpoint.clone(),
endpoint: ctx.router_config.health_check.endpoint.clone(), failure_threshold: ctx.router_config.health_check.failure_threshold,
failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, })
}) .build();
.grpc_client(client)
.build(); worker_registry.register(Arc::new(worker));
info!(
"Registered gRPC decode worker at {} (will connect on first use)",
url
);
}
// Register worker in the centralized registry if prefill_urls.is_empty() && decode_urls.is_empty() {
worker_registry.register(Arc::new(worker)); return Err("No gRPC workers configured".to_string());
}
} }
// Initialize policies with workers if needed - filter for gRPC workers only // Initialize policies with workers if needed - filter for gRPC workers only
......
// gRPC Router Implementation // gRPC Router Implementation
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
...@@ -96,51 +95,35 @@ impl GrpcRouter { ...@@ -96,51 +95,35 @@ impl GrpcRouter {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
}; };
// Create gRPC clients for each worker
let mut grpc_clients = HashMap::new();
for url in &worker_urls {
match SglangSchedulerClient::connect(url).await {
Ok(client) => {
grpc_clients.insert(url.clone(), client);
info!("Connected to gRPC worker at {}", url);
}
Err(e) => {
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
// Continue with other workers
}
}
}
if grpc_clients.is_empty() {
return Err("Failed to connect to any gRPC workers".to_string());
}
// Get registries from context // Get registries from context
let worker_registry = ctx.worker_registry.clone(); let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone(); let policy_registry = ctx.policy_registry.clone();
// Create Worker trait objects with gRPC connection mode and register them // Create Worker trait objects with gRPC connection mode and register them
// Workers will lazily initialize their gRPC clients on first use
for url in &worker_urls { for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) { let worker = BasicWorkerBuilder::new(url.clone())
let worker = BasicWorkerBuilder::new(url.clone()) .worker_type(WorkerType::Regular)
.worker_type(WorkerType::Regular) .connection_mode(crate::core::ConnectionMode::Grpc { port: None })
.connection_mode(crate::core::ConnectionMode::Grpc { port: None }) .circuit_breaker_config(core_cb_config.clone())
.circuit_breaker_config(core_cb_config.clone()) .health_config(HealthConfig {
.health_config(HealthConfig { timeout_secs: ctx.router_config.health_check.timeout_secs,
timeout_secs: ctx.router_config.health_check.timeout_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs, endpoint: ctx.router_config.health_check.endpoint.clone(),
endpoint: ctx.router_config.health_check.endpoint.clone(), failure_threshold: ctx.router_config.health_check.failure_threshold,
failure_threshold: ctx.router_config.health_check.failure_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, })
}) .build();
.grpc_client(client)
.build();
// Register worker in the centralized registry worker_registry.register(Arc::new(worker));
worker_registry.register(Arc::new(worker)); info!(
} else { "Registered gRPC worker at {} (will connect on first use)",
warn!("No gRPC client for worker {}, skipping", url); url
} );
}
if worker_urls.is_empty() {
return Err("No gRPC workers configured".to_string());
} }
// Get only gRPC workers from registry for policy initialization // Get only gRPC workers from registry for policy initialization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment