Unverified commit ca240eef, authored by Chang Su, committed by GitHub


[router][grpc] Support parallel queue puts in grpc_request_manager and remove mutex for grpc_client (#11798)
parent 6c7c92eb
@@ -443,10 +443,11 @@ class GrpcRequestManager:
                 recv_obj = await self.recv_from_scheduler.recv_pyobj()
                 self.last_receive_tstamp = time.time()

-                # Check for pause
-                async with self.is_pause_cond:
-                    while self.is_pause:
-                        await self.is_pause_cond.wait()
+                # Check for pause (optimized: check flag before acquiring lock)
+                if self.is_pause:
+                    async with self.is_pause_cond:
+                        while self.is_pause:
+                            await self.is_pause_cond.wait()

                 # Handle different output types
                 if isinstance(recv_obj, BatchTokenIDOutput):
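The hunk above adds a fast path: `self.is_pause` is read before `self.is_pause_cond` is entered, so the hot receive loop only touches the condition's lock while a pause is actually in effect, and the `while` re-check inside the lock still guards against races. A rough Rust analogue of the same idea (an illustration, not code from this PR), sketched with a hypothetical `tokio::sync::watch` channel carrying the pause flag:

use tokio::sync::watch;

// Hypothetical receive loop: only await a pause signal when actually paused.
async fn recv_loop(mut pause_rx: watch::Receiver<bool>) {
    loop {
        // Fast path: a cheap borrow when not paused (the common case).
        let paused = *pause_rx.borrow();
        if paused {
            // Slow path: wait until the flag flips back to false. `wait_for`
            // re-checks the predicate on every update, mirroring the
            // `while self.is_pause:` loop in the Python code.
            let _ = pause_rx.wait_for(|p| !*p).await;
        }
        // ... receive and dispatch one scheduler message here ...
    }
}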
@@ -531,6 +532,11 @@ class GrpcRequestManager:
     async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
         """Handle batch generation output from scheduler."""

+        # Collect all queue.put() tasks for parallel execution
+        put_tasks = []
+        cleanup_tasks = []
+        now = time.time()
+
         # Process each request in the batch
         for i, rid in enumerate(batch_out.rids):
             if rid not in self.rid_to_state:
@@ -544,7 +550,6 @@ class GrpcRequestManager:
                 continue

             # Update metrics
-            now = time.time()
             if state.first_token_time == 0.0:
                 state.first_token_time = now
             state.last_time = now
@@ -638,7 +643,8 @@ class GrpcRequestManager:
             if output_data["token_ids"]:
                 state.output_ids.extend(output_data["token_ids"])

-            await state.out_queue.put(output_data)
+            # Add queue.put() to parallel task list
+            put_tasks.append(state.out_queue.put(output_data))

             # Handle completion
             if output_data["finished"]:
@@ -648,12 +654,16 @@ class GrpcRequestManager:
                 state.event.set()

                 # Remove from tracking after a delay
-                async def cleanup():
+                async def cleanup(request_id):
                     await asyncio.sleep(5.0)
-                    if rid in self.rid_to_state:
-                        del self.rid_to_state[rid]
+                    if request_id in self.rid_to_state:
+                        del self.rid_to_state[request_id]

-                asyncio.create_task(cleanup())
+                cleanup_tasks.append(asyncio.create_task(cleanup(rid)))
+
+        # Execute all queue.put() operations in parallel
+        if put_tasks:
+            await asyncio.gather(*put_tasks, return_exceptions=True)

     async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
         """Handle batch embedding output from scheduler."""
......
@@ -10,10 +10,7 @@ use std::{
 use async_trait::async_trait;
 use futures;
 use serde_json;
-use tokio::{
-    sync::{Mutex, RwLock},
-    time,
-};
+use tokio::{sync::RwLock, time};

 use super::{CircuitBreaker, WorkerError, WorkerResult};
 use crate::{
@@ -232,7 +229,7 @@ pub trait Worker: Send + Sync + fmt::Debug {
     /// Get or create a gRPC client for this worker
     /// Returns None for HTTP workers, Some(client) for gRPC workers
-    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>>;
+    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<SglangSchedulerClient>>>;

     /// Reset the gRPC client connection (for reconnection scenarios)
     /// No-op for HTTP workers
@@ -367,7 +364,7 @@ pub struct BasicWorker {
     pub consecutive_successes: Arc<AtomicUsize>,
     pub circuit_breaker: CircuitBreaker,
     /// Lazily initialized gRPC client for gRPC workers
-    pub grpc_client: Arc<RwLock<Option<Arc<Mutex<SglangSchedulerClient>>>>>,
+    pub grpc_client: Arc<RwLock<Option<Arc<SglangSchedulerClient>>>>,
 }

 impl fmt::Debug for BasicWorker {
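The inner `Mutex` can be dropped because the router only ever uses the client through shared references: `health_check` is called directly on the `Arc` later in this diff, and the request path clones the client out of the `Arc` instead of locking it. A minimal standalone sketch of why that is enough, assuming (as the diff implies) that the client's shared methods take `&self`; `FakeClient` here is a stand-in, not the real `SglangSchedulerClient`:

use std::sync::Arc;

// Stand-in for a client whose methods only need `&self`; the real client is
// assumed to behave this way, which is what lets it be shared as
// Arc<SglangSchedulerClient> without a Mutex.
#[derive(Clone)]
struct FakeClient;

impl FakeClient {
    async fn health_check(&self) -> bool {
        true
    }
}

#[tokio::main]
async fn main() {
    let client = Arc::new(FakeClient);
    let mut handles = Vec::new();
    for _ in 0..4 {
        let c = Arc::clone(&client);
        // Every task calls the shared client directly; with the previous
        // Arc<Mutex<_>> layout these calls would have been serialized.
        handles.push(tokio::spawn(async move { c.health_check().await }));
    }
    for h in handles {
        let _ = h.await;
    }
}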
@@ -505,7 +502,7 @@ impl Worker for BasicWorker {
         &self.circuit_breaker
     }

-    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>> {
+    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<SglangSchedulerClient>>> {
         match self.metadata.connection_mode {
             ConnectionMode::Http => Ok(None),
             ConnectionMode::Grpc { .. } => {
@@ -528,7 +525,7 @@ impl Worker for BasicWorker {
                     );
                     match SglangSchedulerClient::connect(&self.metadata.url).await {
                         Ok(client) => {
-                            let client_arc = Arc::new(Mutex::new(client));
+                            let client_arc = Arc::new(client);
                             *client_guard = Some(client_arc.clone());
                             tracing::info!(
                                 "Successfully connected gRPC client for worker: {}",
@@ -577,8 +574,7 @@ impl Worker for BasicWorker {
                     return Ok(false);
                 };

-                let client = grpc_client.lock().await;
-                match time::timeout(timeout, client.health_check()).await {
+                match time::timeout(timeout, grpc_client.health_check()).await {
                     Ok(Ok(resp)) => {
                         tracing::debug!(
                             "gRPC health OK for {}: healthy={}",
@@ -749,7 +745,7 @@ impl Worker for DPAwareWorker {
         format!("{}{}", self.base_url, route)
     }

-    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<Mutex<SglangSchedulerClient>>>> {
+    async fn get_grpc_client(&self) -> WorkerResult<Option<Arc<SglangSchedulerClient>>> {
         self.base_worker.get_grpc_client().await
     }
......
@@ -104,7 +104,7 @@ impl BasicWorkerBuilder {
            Arc,
        };
-        use tokio::sync::{Mutex, RwLock};
+        use tokio::sync::RwLock;

        let bootstrap_host = match url::Url::parse(&self.url) {
            Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
@@ -145,9 +145,7 @@ impl BasicWorkerBuilder {
            bootstrap_port,
        };

-        let grpc_client = Arc::new(RwLock::new(
-            self.grpc_client.map(|client| Arc::new(Mutex::new(client))),
-        ));
+        let grpc_client = Arc::new(RwLock::new(self.grpc_client.map(Arc::new)));

        BasicWorker {
            metadata,
......
@@ -42,8 +42,7 @@ pub async fn get_grpc_client_from_worker(
        .map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))?
        .ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?;

-    let client = client_arc.lock().await.clone();
-    Ok(client)
+    Ok((*client_arc).clone())
 }

 /// Process tool call arguments in messages
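One detail worth noting in the helper above: `(*client_arc).clone()` dereferences the `Arc` first, so it is the client itself that gets cloned and returned by value, whereas a plain `client_arc.clone()` would auto-ref and merely bump the `Arc`'s refcount. An owned client is convenient when a call needs `&mut self` or the handle is moved into another task, and the clone is assumed to be cheap here because tonic-style generated clients share the underlying channel between clones. A small sketch of the distinction, with a hypothetical `Client` type:

use std::sync::Arc;

// Stand-in for a cheaply cloneable client (cloning just bumps the refcount of
// an internal channel handle) -- the assumption behind `Ok((*client_arc).clone())`.
#[derive(Clone)]
struct Client {
    channel: Arc<String>,
}

fn owned_copy(client_arc: Arc<Client>) -> Client {
    // Clones the inner Client; its `channel` Arc refcount goes up.
    (*client_arc).clone()
}

fn shared_copy(client_arc: &Arc<Client>) -> Arc<Client> {
    // Clones only the outer Arc; both handles point at the same Client.
    Arc::clone(client_arc)
}

fn main() {
    let client = Arc::new(Client { channel: Arc::new(String::from("grpc://worker")) });
    let shared = shared_copy(&client);
    let owned = owned_copy(client);
    assert_eq!(*owned.channel, *shared.channel);
}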
......