Unverified Commit c8845b41 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

fix: worker to graceful shutdown after finishing in-flight requests (#4838)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarBiswa Panda <biswa.panda@gmail.com>
parent 69817c2d
...@@ -60,7 +60,11 @@ impl MockVllmEngine { ...@@ -60,7 +60,11 @@ impl MockVllmEngine {
} }
pub async fn start(&self, component: Component) -> Result<()> { pub async fn start(&self, component: Component) -> Result<()> {
let cancel_token = component.drt().runtime().child_token(); // Use primary_token() instead of child_token() so the mocker continues running
// during graceful shutdown (Phase 1/2) and only stops in Phase 3.
// child_token() is a child of endpoint_shutdown_token which is cancelled in Phase 1.
// primary_token() is only cancelled in Phase 3, after waiting for inflight requests.
let cancel_token = component.drt().primary_token();
// Simulate engine startup time if configured // Simulate engine startup time if configured
if let Some(startup_time_secs) = self.engine_args.startup_time { if let Some(startup_time_secs) = self.engine_args.startup_time {
...@@ -143,6 +147,11 @@ impl MockVllmEngine { ...@@ -143,6 +147,11 @@ impl MockVllmEngine {
} }
} }
_ = cancel_token_cloned.cancelled() => { _ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests");
// Clear all active requests to unblock waiting request handlers
// This will cause their request_rx.recv() to return None
let mut active = active_requests_clone.lock().await;
active.clear();
break; break;
} }
} }
......
...@@ -105,7 +105,27 @@ impl SharedHttpServer { ...@@ -105,7 +105,27 @@ impl SharedHttpServer {
.system_health .system_health
.lock() .lock()
.set_endpoint_health_status(endpoint_name, HealthStatus::NotReady); .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
tracing::debug!("Unregistered endpoint handler for subject: {}", subject); tracing::debug!(
endpoint_name = %endpoint_name,
subject = %subject,
"Unregistered HTTP endpoint handler"
);
let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight HTTP requests to complete"
);
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight HTTP requests completed"
);
}
} }
} }
......
...@@ -32,6 +32,7 @@ pub struct NatsMultiplexedServer { ...@@ -32,6 +32,7 @@ pub struct NatsMultiplexedServer {
struct EndpointTask { struct EndpointTask {
cancel_token: CancellationToken, cancel_token: CancellationToken,
join_handle: tokio::task::JoinHandle<()>,
_endpoint_name: String, _endpoint_name: String,
} }
...@@ -145,7 +146,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer { ...@@ -145,7 +146,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
// Spawn task to handle this endpoint using PushEndpoint // Spawn task to handle this endpoint using PushEndpoint
// Note: PushEndpoint::start() is a blocking loop that runs until cancelled // Note: PushEndpoint::start() is a blocking loop that runs until cancelled
let endpoint_name_clone = endpoint_name.clone(); let endpoint_name_clone = endpoint_name.clone();
tokio::spawn(async move { let join_handle = tokio::spawn(async move {
if let Err(e) = push_endpoint if let Err(e) = push_endpoint
.start( .start(
service_endpoint, service_endpoint,
...@@ -180,6 +181,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer { ...@@ -180,6 +181,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
endpoint_name.clone(), endpoint_name.clone(),
EndpointTask { EndpointTask {
cancel_token: endpoint_cancel, cancel_token: endpoint_cancel,
join_handle,
_endpoint_name: endpoint_name, _endpoint_name: endpoint_name,
}, },
); );
...@@ -193,7 +195,25 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer { ...@@ -193,7 +195,25 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
endpoint_name = %endpoint_name, endpoint_name = %endpoint_name,
"Unregistering NATS endpoint" "Unregistering NATS endpoint"
); );
// Cancel the token to trigger graceful shutdown
task.cancel_token.cancel(); task.cancel_token.cancel();
// Wait for the endpoint task to complete (which includes waiting for inflight requests)
tracing::debug!(
endpoint_name = %endpoint_name,
"Waiting for NATS endpoint task to complete"
);
if let Err(e) = task.join_handle.await {
tracing::warn!(
endpoint_name = %endpoint_name,
error = %e,
"NATS endpoint task panicked during shutdown"
);
}
tracing::info!(
endpoint_name = %endpoint_name,
"NATS endpoint unregistration complete"
);
} }
Ok(()) Ok(())
} }
......
...@@ -135,16 +135,26 @@ impl PushEndpoint { ...@@ -135,16 +135,26 @@ impl PushEndpoint {
// await for all inflight requests to complete if graceful shutdown // await for all inflight requests to complete if graceful shutdown
if self.graceful_shutdown { if self.graceful_shutdown {
tracing::info!( let inflight_count = inflight.load(Ordering::SeqCst);
"Waiting for {} inflight requests to complete", if inflight_count > 0 {
inflight.load(Ordering::SeqCst) tracing::info!(
); endpoint_name = endpoint_name_local.as_str(),
while inflight.load(Ordering::SeqCst) > 0 { inflight_count = inflight_count,
notify.notified().await; "Waiting for inflight NATS requests to complete"
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
}
tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
"All inflight NATS requests completed"
);
} }
tracing::info!("All inflight requests completed");
} else { } else {
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests"); tracing::info!(
endpoint_name = endpoint_name_local.as_str(),
"Skipping graceful shutdown, not waiting for inflight requests"
);
} }
Ok(()) Ok(())
......
...@@ -100,11 +100,33 @@ impl SharedTcpServer { ...@@ -100,11 +100,33 @@ impl SharedTcpServer {
} }
pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) { pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) {
self.handlers.remove(endpoint_path); if let Some((_, handler)) = self.handlers.remove(endpoint_path) {
tracing::info!( handler
"Unregistered endpoint '{}' from shared TCP server", .system_health
endpoint_name .lock()
); .set_endpoint_health_status(endpoint_name, crate::HealthStatus::NotReady);
tracing::info!(
endpoint_name = %endpoint_name,
endpoint_path = %endpoint_path,
"Unregistered TCP endpoint handler"
);
let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight TCP requests to complete"
);
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight TCP requests completed"
);
}
}
} }
pub async fn start(self: Arc<Self>) -> Result<()> { pub async fn start(self: Arc<Self>) -> Result<()> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment