Unverified commit c8845b41, authored by Hongkuan Zhou, committed by GitHub
Browse files

fix: worker to graceful shutdown after finishing in-flight requests (#4838)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
Co-authored-by: Biswa Panda <biswa.panda@gmail.com>
parent 69817c2d
......@@ -60,7 +60,11 @@ impl MockVllmEngine {
}
pub async fn start(&self, component: Component) -> Result<()> {
let cancel_token = component.drt().runtime().child_token();
// Use primary_token() instead of child_token() so the mocker continues running
// during graceful shutdown (Phase 1/2) and only stops in Phase 3.
// child_token() is a child of endpoint_shutdown_token which is cancelled in Phase 1.
// primary_token() is only cancelled in Phase 3, after waiting for inflight requests.
let cancel_token = component.drt().primary_token();
// Simulate engine startup time if configured
if let Some(startup_time_secs) = self.engine_args.startup_time {
......@@ -143,6 +147,11 @@ impl MockVllmEngine {
}
}
_ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests");
// Clear all active requests to unblock waiting request handlers
// This will cause their request_rx.recv() to return None
let mut active = active_requests_clone.lock().await;
active.clear();
break;
}
}
......
......@@ -105,7 +105,27 @@ impl SharedHttpServer {
.system_health
.lock()
.set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
tracing::debug!("Unregistered endpoint handler for subject: {}", subject);
tracing::debug!(
endpoint_name = %endpoint_name,
subject = %subject,
"Unregistered HTTP endpoint handler"
);
let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight HTTP requests to complete"
);
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight HTTP requests completed"
);
}
}
}
......
......@@ -32,6 +32,7 @@ pub struct NatsMultiplexedServer {
/// Bookkeeping for one registered NATS endpoint.
///
/// An instance is stored per endpoint name at registration time and taken
/// back out at unregistration, where the token is cancelled and the task is
/// awaited.
struct EndpointTask {
/// Cancelled during unregistration to ask the endpoint task to shut down.
cancel_token: CancellationToken,
/// Handle of the spawned task that runs the endpoint loop; awaited during
/// unregistration so the caller returns only after the task has finished.
join_handle: tokio::task::JoinHandle<()>,
/// Copy of the endpoint name kept alongside the task.
/// NOTE(review): the leading underscore suggests it is not read anywhere — confirm.
_endpoint_name: String,
}
......@@ -145,7 +146,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
// Spawn task to handle this endpoint using PushEndpoint
// Note: PushEndpoint::start() is a blocking loop that runs until cancelled
let endpoint_name_clone = endpoint_name.clone();
tokio::spawn(async move {
let join_handle = tokio::spawn(async move {
if let Err(e) = push_endpoint
.start(
service_endpoint,
......@@ -180,6 +181,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
endpoint_name.clone(),
EndpointTask {
cancel_token: endpoint_cancel,
join_handle,
_endpoint_name: endpoint_name,
},
);
......@@ -193,7 +195,25 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
endpoint_name = %endpoint_name,
"Unregistering NATS endpoint"
);
// Cancel the token to trigger graceful shutdown
task.cancel_token.cancel();
// Wait for the endpoint task to complete (which includes waiting for inflight requests)
tracing::debug!(
endpoint_name = %endpoint_name,
"Waiting for NATS endpoint task to complete"
);
if let Err(e) = task.join_handle.await {
tracing::warn!(
endpoint_name = %endpoint_name,
error = %e,
"NATS endpoint task panicked during shutdown"
);
}
tracing::info!(
endpoint_name = %endpoint_name,
"NATS endpoint unregistration complete"
);
}
Ok(())
}
......
......@@ -135,16 +135,26 @@ impl PushEndpoint {
// await for all inflight requests to complete if graceful shutdown
if self.graceful_shutdown {
tracing::info!(
"Waiting for {} inflight requests to complete",
inflight.load(Ordering::SeqCst)
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
let inflight_count = inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
inflight_count = inflight_count,
"Waiting for inflight NATS requests to complete"
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
}
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
"All inflight NATS requests completed"
);
}
tracing::info!("All inflight requests completed");
} else {
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
"Skipping graceful shutdown, not waiting for inflight requests"
);
}
Ok(())
......
......@@ -100,11 +100,33 @@ impl SharedTcpServer {
}
/// Unregisters the handler stored under `endpoint_path`.
///
/// When a handler is found, the health status recorded for `endpoint_name`
/// is set to `NotReady`, the removal is logged, and — if the handler's
/// inflight counter is non-zero — the call waits on the handler's `notify`
/// until the counter reads zero before returning.
pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) {
// NOTE(review): this text comes from a rendered diff. The bare `remove` and
// the single-message log directly below appear to be the pre-change lines
// that the `if let` block replaces. Read literally, this first `remove`
// would leave nothing for the second one to find — confirm against the
// actual file.
self.handlers.remove(endpoint_path);
tracing::info!(
"Unregistered endpoint '{}' from shared TCP server",
endpoint_name
);
// Take the handler out of the map first; everything below works on the
// removed handler.
if let Some((_, handler)) = self.handlers.remove(endpoint_path) {
// Flag the endpoint as not ready before waiting on requests in progress.
handler
.system_health
.lock()
.set_endpoint_health_status(endpoint_name, crate::HealthStatus::NotReady);
tracing::info!(
endpoint_name = %endpoint_name,
endpoint_path = %endpoint_path,
"Unregistered TCP endpoint handler"
);
// Snapshot the counter once so the "waiting" log is emitted only when
// there is something to wait for.
let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight TCP requests to complete"
);
// Re-check the counter after every wake-up.
// NOTE(review): the counter is read before `notified()` is created. If
// the last request finishes in between and the notifier does not leave
// a stored permit, this wait may never be woken — confirm how `notify`
// is signalled.
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight TCP requests completed"
);
}
}
}
pub async fn start(self: Arc<Self>) -> Result<()> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment