Unverified Commit 11dcabc5 authored by Chang Su, committed by GitHub

Grpc client (#9939)

parent 4d89389c
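Summary of the change, as read from the hunks below: BasicWorker gains an optional gRPC client (plus a hand-written Debug impl to go with it), check_health_async dispatches on the worker's connection mode (HTTP request vs. gRPC HealthCheck RPC), SglangSchedulerClient::connect rewrites grpc:// endpoints into the http:// form tonic expects, and GrpcRouter attaches a client to each worker and reports its real worker URLs. A minimal sketch of the new client surface on its own, assuming only the signatures shown in this diff (the endpoint value is a placeholder and a tokio runtime is assumed):

    use crate::grpc::SglangSchedulerClient;

    /// Probe one scheduler endpoint and report whether it claims to be healthy.
    async fn probe(endpoint: &str) -> Result<bool, Box<dyn std::error::Error>> {
        // connect() rewrites grpc:// to http:// before tonic dials the channel
        let mut client = SglangSchedulerClient::connect(endpoint).await?;
        let resp = client.health_check().await?;
        Ok(resp.healthy)
    }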
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use async_trait::async_trait;
use futures;
@@ -6,6 +7,7 @@ use serde_json;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock};
use tokio::sync::Mutex;
// Shared HTTP client for worker operations (health checks, server info, etc.)
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
@@ -249,7 +251,7 @@ pub struct WorkerMetadata {
}
/// Basic worker implementation
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct BasicWorker {
metadata: WorkerMetadata,
load_counter: Arc<AtomicUsize>,
@@ -258,6 +260,19 @@ pub struct BasicWorker {
consecutive_failures: Arc<AtomicUsize>,
consecutive_successes: Arc<AtomicUsize>,
circuit_breaker: CircuitBreaker,
/// Optional gRPC client for gRPC workers
grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
}
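// Debug is implemented by hand below (the derive was dropped), presumably because the
// scheduler client held in grpc_client does not itself implement Debug; the client is
// therefore surfaced only as a has_grpc_client presence flag.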
impl fmt::Debug for BasicWorker {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BasicWorker")
.field("metadata", &self.metadata)
.field("healthy", &self.healthy.load(Ordering::Relaxed))
.field("circuit_breaker", &self.circuit_breaker)
.field("has_grpc_client", &self.grpc_client.is_some())
.finish()
}
}
impl BasicWorker {
@@ -286,6 +301,7 @@ impl BasicWorker {
consecutive_failures: Arc::new(AtomicUsize::new(0)),
consecutive_successes: Arc::new(AtomicUsize::new(0)),
circuit_breaker: CircuitBreaker::new(),
grpc_client: None,
}
}
@@ -304,6 +320,12 @@ impl BasicWorker {
self
}
/// Set the gRPC client for gRPC workers
pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self {
self.grpc_client = Some(Arc::new(Mutex::new(client)));
self
}
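// Expected wiring (see the GrpcRouter hunk further down): the router connects one
// SglangSchedulerClient per endpoint and hands it to the matching worker, roughly
//     BasicWorker::with_connection_mode(url, WorkerType::Regular, /* ... */)
//         .with_grpc_client(client)
// HTTP workers never call this, so grpc_client stays None for them.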
pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
@@ -352,15 +374,46 @@ impl Worker for BasicWorker {
async fn check_health_async(&self) -> WorkerResult<()> {
use std::time::Duration;
// Perform actual HTTP health check
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
let health_result = match &self.metadata.connection_mode {
ConnectionMode::Http => {
// Perform HTTP health check
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
ConnectionMode::Grpc { .. } => {
// Perform gRPC health check
if let Some(grpc_client) = &self.grpc_client {
let mut client = grpc_client.lock().await;
match client.health_check().await {
Ok(response) => {
tracing::debug!(
"gRPC health check succeeded for {}: healthy={}",
self.metadata.url,
response.healthy
);
response.healthy
}
Err(e) => {
tracing::warn!(
"gRPC health check RPC failed for {}: {:?}",
self.metadata.url,
e
);
false
}
}
} else {
tracing::error!("No gRPC client available for worker {}", self.metadata.url);
false
}
}
};
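// Both arms reduce to a plain bool, so the consecutive-success/failure accounting
// below is shared by HTTP and gRPC workers.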
if health_result {
@@ -390,7 +443,7 @@ impl Worker for BasicWorker {
}
Err(WorkerError::HealthCheckFailed {
url: url.to_string(),
url: self.metadata.url.clone(),
reason: format!("Health check failed (consecutive failures: {})", failures),
})
}
@@ -1491,12 +1544,17 @@ mod tests {
// Clone for use inside catch_unwind
let worker_clone = Arc::clone(&worker);
// Use AssertUnwindSafe wrapper for the test
// This is safe because we're only testing the load counter behavior,
// not the grpc_client which is None for HTTP workers
use std::panic::AssertUnwindSafe;
// This will panic, but the guard should still clean up
let result = std::panic::catch_unwind(|| {
let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
assert_eq!(worker_clone.load(), 1);
panic!("Test panic");
});
}));
// Verify panic occurred
assert!(result.is_err());
...
@@ -20,7 +20,14 @@ impl SglangSchedulerClient {
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
debug!("Connecting to SGLang scheduler at {}", endpoint);
let channel = Channel::from_shared(endpoint.to_string())?
// Convert grpc:// to http:// for tonic
let http_endpoint = if endpoint.starts_with("grpc://") {
endpoint.replace("grpc://", "http://")
} else {
endpoint.to_string()
};
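// e.g. a hypothetical "grpc://10.0.0.5:30000" becomes "http://10.0.0.5:30000";
// tonic's Channel expects an http(s) URI here.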
let channel = Channel::from_shared(http_endpoint)?
.timeout(Duration::from_secs(30))
.connect()
.await?;
@@ -59,11 +66,13 @@ impl SglangSchedulerClient {
pub async fn health_check(
&mut self,
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
debug!("Sending health check request");
let request = Request::new(proto::HealthCheckRequest {
include_detailed_metrics: false,
});
let response = self.client.health_check(request).await?;
debug!("Health check response received");
Ok(response.into_inner())
}
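// Note: health_check takes &mut self, which is why BasicWorker stores this client
// behind a tokio::sync::Mutex and locks it for each health probe.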
...
@@ -108,9 +108,11 @@ impl GrpcRouter {
}
// Create Worker trait objects with gRPC connection mode
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
.map(|url| {
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
// Move clients from the HashMap to the workers
for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorker::with_connection_mode(
url.clone(),
WorkerType::Regular,
@@ -123,10 +125,14 @@
endpoint: health_check_config.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold,
success_threshold: health_check_config.success_threshold,
});
Box::new(worker) as Box<dyn Worker>
})
.collect();
})
.with_grpc_client(client);
workers.push(Box::new(worker) as Box<dyn Worker>);
} else {
warn!("No gRPC client for worker {}, skipping", url);
}
}
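// URLs with no entry left in grpc_clients (presumably endpoints whose earlier connection
// attempt failed) are skipped with a warning rather than aborting router construction.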
// Initialize policy with workers if needed
if let Some(cache_aware) = policy
@@ -252,6 +258,11 @@ impl WorkerManagement for GrpcRouter {
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
vec![]
self.workers
.read()
.unwrap()
.iter()
.map(|w| w.url().to_string())
.collect()
}
}