"...hubert/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "ddb9fb5bbf26ec8cbf50004dfc252792ab13bd54"
Unverified Commit 11dcabc5 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

Grpc client (#9939)

parent 4d89389c
use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult}; use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use async_trait::async_trait; use async_trait::async_trait;
use futures; use futures;
...@@ -6,6 +7,7 @@ use serde_json; ...@@ -6,6 +7,7 @@ use serde_json;
use std::fmt; use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock}; use std::sync::{Arc, LazyLock};
use tokio::sync::Mutex;
// Shared HTTP client for worker operations (health checks, server info, etc.) // Shared HTTP client for worker operations (health checks, server info, etc.)
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
...@@ -249,7 +251,7 @@ pub struct WorkerMetadata { ...@@ -249,7 +251,7 @@ pub struct WorkerMetadata {
} }
/// Basic worker implementation /// Basic worker implementation
#[derive(Debug, Clone)] #[derive(Clone)]
pub struct BasicWorker { pub struct BasicWorker {
metadata: WorkerMetadata, metadata: WorkerMetadata,
load_counter: Arc<AtomicUsize>, load_counter: Arc<AtomicUsize>,
...@@ -258,6 +260,19 @@ pub struct BasicWorker { ...@@ -258,6 +260,19 @@ pub struct BasicWorker {
consecutive_failures: Arc<AtomicUsize>, consecutive_failures: Arc<AtomicUsize>,
consecutive_successes: Arc<AtomicUsize>, consecutive_successes: Arc<AtomicUsize>,
circuit_breaker: CircuitBreaker, circuit_breaker: CircuitBreaker,
/// Optional gRPC client for gRPC workers
grpc_client: Option<Arc<Mutex<SglangSchedulerClient>>>,
}
impl fmt::Debug for BasicWorker {
/// Formats a diagnostic summary of the worker.
///
/// The atomic counters and the health flag are read with `Ordering::Relaxed`
/// at formatting time, so the output is a best-effort snapshot. The gRPC
/// client is reported only as the presence flag `has_grpc_client`. Any
/// field not listed here is elided, which the trailing `..` produced by
/// `finish_non_exhaustive` makes visible in the output.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.debug_struct("BasicWorker")
        .field("metadata", &self.metadata)
        .field("load_counter", &self.load_counter.load(Ordering::Relaxed))
        .field("healthy", &self.healthy.load(Ordering::Relaxed))
        .field(
            "consecutive_failures",
            &self.consecutive_failures.load(Ordering::Relaxed),
        )
        .field(
            "consecutive_successes",
            &self.consecutive_successes.load(Ordering::Relaxed),
        )
        .field("circuit_breaker", &self.circuit_breaker)
        .field("has_grpc_client", &self.grpc_client.is_some())
        // Signal that the struct has more fields than are printed here.
        .finish_non_exhaustive()
}
} }
impl BasicWorker { impl BasicWorker {
...@@ -286,6 +301,7 @@ impl BasicWorker { ...@@ -286,6 +301,7 @@ impl BasicWorker {
consecutive_failures: Arc::new(AtomicUsize::new(0)), consecutive_failures: Arc::new(AtomicUsize::new(0)),
consecutive_successes: Arc::new(AtomicUsize::new(0)), consecutive_successes: Arc::new(AtomicUsize::new(0)),
circuit_breaker: CircuitBreaker::new(), circuit_breaker: CircuitBreaker::new(),
grpc_client: None,
} }
} }
...@@ -304,6 +320,12 @@ impl BasicWorker { ...@@ -304,6 +320,12 @@ impl BasicWorker {
self self
} }
/// Attaches a gRPC scheduler client to this worker, builder-style.
///
/// The client is wrapped in `Arc<Mutex<_>>` and stored in `grpc_client`,
/// replacing any client set earlier. Consumes `self` and returns it so the
/// call can be chained after a constructor.
pub fn with_grpc_client(mut self, client: SglangSchedulerClient) -> Self {
    let shared_client = Arc::new(Mutex::new(client));
    self.grpc_client = Some(shared_client);
    self
}
pub fn normalised_url(&self) -> WorkerResult<&str> { pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") { if self.url().contains("@") {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
...@@ -352,15 +374,46 @@ impl Worker for BasicWorker { ...@@ -352,15 +374,46 @@ impl Worker for BasicWorker {
async fn check_health_async(&self) -> WorkerResult<()> { async fn check_health_async(&self) -> WorkerResult<()> {
use std::time::Duration; use std::time::Duration;
// Perform actual HTTP health check let health_result = match &self.metadata.connection_mode {
let url = self.normalised_url()?; ConnectionMode::Http => {
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint); // Perform HTTP health check
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs); let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
// Use the shared client with a custom timeout for this request let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
let health_result = match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => response.status().is_success(), // Use the shared client with a custom timeout for this request
Err(_) => false, match WORKER_CLIENT.get(&health_url).timeout(timeout).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
ConnectionMode::Grpc { .. } => {
// Perform gRPC health check
if let Some(grpc_client) = &self.grpc_client {
let mut client = grpc_client.lock().await;
match client.health_check().await {
Ok(response) => {
tracing::debug!(
"gRPC health check succeeded for {}: healthy={}",
self.metadata.url,
response.healthy
);
response.healthy
}
Err(e) => {
tracing::warn!(
"gRPC health check RPC failed for {}: {:?}",
self.metadata.url,
e
);
false
}
}
} else {
tracing::error!("No gRPC client available for worker {}", self.metadata.url);
false
}
}
}; };
if health_result { if health_result {
...@@ -390,7 +443,7 @@ impl Worker for BasicWorker { ...@@ -390,7 +443,7 @@ impl Worker for BasicWorker {
} }
Err(WorkerError::HealthCheckFailed { Err(WorkerError::HealthCheckFailed {
url: url.to_string(), url: self.metadata.url.clone(),
reason: format!("Health check failed (consecutive failures: {})", failures), reason: format!("Health check failed (consecutive failures: {})", failures),
}) })
} }
...@@ -1491,12 +1544,17 @@ mod tests { ...@@ -1491,12 +1544,17 @@ mod tests {
// Clone for use inside catch_unwind // Clone for use inside catch_unwind
let worker_clone = Arc::clone(&worker); let worker_clone = Arc::clone(&worker);
// Use AssertUnwindSafe wrapper for the test
// This is safe because we're only testing the load counter behavior,
// not the grpc_client which is None for HTTP workers
use std::panic::AssertUnwindSafe;
// This will panic, but the guard should still clean up // This will panic, but the guard should still clean up
let result = std::panic::catch_unwind(|| { let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
let _guard = WorkerLoadGuard::new(worker_clone.as_ref()); let _guard = WorkerLoadGuard::new(worker_clone.as_ref());
assert_eq!(worker_clone.load(), 1); assert_eq!(worker_clone.load(), 1);
panic!("Test panic"); panic!("Test panic");
}); }));
// Verify panic occurred // Verify panic occurred
assert!(result.is_err()); assert!(result.is_err());
......
...@@ -20,7 +20,14 @@ impl SglangSchedulerClient { ...@@ -20,7 +20,14 @@ impl SglangSchedulerClient {
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> { pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
debug!("Connecting to SGLang scheduler at {}", endpoint); debug!("Connecting to SGLang scheduler at {}", endpoint);
let channel = Channel::from_shared(endpoint.to_string())? // Convert grpc:// to http:// for tonic
let http_endpoint = if endpoint.starts_with("grpc://") {
endpoint.replace("grpc://", "http://")
} else {
endpoint.to_string()
};
let channel = Channel::from_shared(http_endpoint)?
.timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(30))
.connect() .connect()
.await?; .await?;
...@@ -59,11 +66,13 @@ impl SglangSchedulerClient { ...@@ -59,11 +66,13 @@ impl SglangSchedulerClient {
pub async fn health_check( pub async fn health_check(
&mut self, &mut self,
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> { ) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
debug!("Sending health check request");
let request = Request::new(proto::HealthCheckRequest { let request = Request::new(proto::HealthCheckRequest {
include_detailed_metrics: false, include_detailed_metrics: false,
}); });
let response = self.client.health_check(request).await?; let response = self.client.health_check(request).await?;
debug!("Health check response received");
Ok(response.into_inner()) Ok(response.into_inner())
} }
......
...@@ -108,9 +108,11 @@ impl GrpcRouter { ...@@ -108,9 +108,11 @@ impl GrpcRouter {
} }
// Create Worker trait objects with gRPC connection mode // Create Worker trait objects with gRPC connection mode
let workers: Vec<Box<dyn Worker>> = worker_urls let mut workers: Vec<Box<dyn Worker>> = Vec::new();
.iter()
.map(|url| { // Move clients from the HashMap to the workers
for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorker::with_connection_mode( let worker = BasicWorker::with_connection_mode(
url.clone(), url.clone(),
WorkerType::Regular, WorkerType::Regular,
...@@ -123,10 +125,14 @@ impl GrpcRouter { ...@@ -123,10 +125,14 @@ impl GrpcRouter {
endpoint: health_check_config.endpoint.clone(), endpoint: health_check_config.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: health_check_config.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: health_check_config.success_threshold,
}); })
Box::new(worker) as Box<dyn Worker> .with_grpc_client(client);
})
.collect(); workers.push(Box::new(worker) as Box<dyn Worker>);
} else {
warn!("No gRPC client for worker {}, skipping", url);
}
}
// Initialize policy with workers if needed // Initialize policy with workers if needed
if let Some(cache_aware) = policy if let Some(cache_aware) = policy
...@@ -252,6 +258,11 @@ impl WorkerManagement for GrpcRouter { ...@@ -252,6 +258,11 @@ impl WorkerManagement for GrpcRouter {
fn remove_worker(&self, _worker_url: &str) {} fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> { fn get_worker_urls(&self) -> Vec<String> {
vec![] self.workers
.read()
.unwrap()
.iter()
.map(|w| w.url().to_string())
.collect()
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment