// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! The `disconnect` module provides a mechanism for our axum http services to monitoring and responding //! to disconnects from the client. //! //! There are two potential phases in any request where we need to handle the disconnect. //! //! For unary, request-response, there is just a single phase where the primary task that axum kicks off //! to handle the request will be dropped if the client disconnects. In order for us to have a long running //! task, like an LLM request, we need to spawn our long running task in a separate task and then spawn //! a second task that will monitor for disconnects from the client. The primary task which spawned the //! two tasks will hold an "armed" [`ConnectionHandle`] which will issue a [`ConnectionStatus::ClosedUnexpectedly`] //! if the task is dropped before it is [`ConnectionHandle::disarm`]ed. //! //! For the streaming case, request in - stream out, we need a second [`ConnectionHandle`] which will be owned //! by the stream. A streaming response is when the [`axum::response::Response]] is a [axum::response::Sse] stream. //! This means the primary task handle will go out of scope when it returns the stream. When we create our //! SSE stream, we capture the second [`ConnectionHandle`] and arm it. If the stream closes gracefully, the //! second handle will be disarmed, otherwise, the stream was dropped and the [`Drop`] trait on the [`ConnectionHandle`] //! triggers a [`ConnectionStatus::ClosedUnexpectedly`] signal. //! //! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] which will send a //! [`ConnectionStatus`] enum to the primary task. The primary task will then use this to determine if it should //! cancel the request or not. //! //! The [`ConnectionHandle`] is also used to signal to the client that the request has been cancelled. This is //! done by sending a [`axum::response::sse::Event`] with the event type "error" and the data "[DONE]". //! use axum::response::sse::Event; use dynamo_runtime::engine::AsyncEngineContext; use futures::{Stream, StreamExt}; use std::sync::Arc; use crate::http::service::metrics::{CancellationLabels, ErrorType, InflightGuard, Metrics}; #[derive(Clone, Copy)] pub enum ConnectionStatus { Disabled, ClosedUnexpectedly, ClosedGracefully, } pub struct ConnectionHandle { sender: Option>, on_drop: ConnectionStatus, } impl ConnectionHandle { /// Handle which by default will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped. pub fn create_disarmed(sender: tokio::sync::oneshot::Sender) -> Self { Self { sender: Some(sender), on_drop: ConnectionStatus::ClosedGracefully, } } /// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped. pub fn create_armed(sender: tokio::sync::oneshot::Sender) -> Self { Self { sender: Some(sender), on_drop: ConnectionStatus::ClosedUnexpectedly, } } /// Handle which will not issue a signal when dropped. pub fn create_disabled(sender: tokio::sync::oneshot::Sender) -> Self { Self { sender: Some(sender), on_drop: ConnectionStatus::Disabled, } } /// Handle which will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped. pub fn disarm(&mut self) { self.on_drop = ConnectionStatus::ClosedGracefully; } /// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped. pub fn arm(&mut self) { self.on_drop = ConnectionStatus::ClosedUnexpectedly; } } impl Drop for ConnectionHandle { fn drop(&mut self) { if let Some(sender) = self.sender.take() { let _ = sender.send(self.on_drop); } } } /// Creates a pair of handles which will monitor for disconnects from the client. /// /// The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped. /// The second handle is disarmed and will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped. /// /// The handles are returned in the order of the first being armed and the second being disarmed. pub async fn create_connection_monitor( engine_context: Arc, metrics: Option>, cancellation_labels: CancellationLabels, ) -> (ConnectionHandle, ConnectionHandle) { // these oneshot channels monitor possible disconnects from the client in two different scopes: // - the local task (connection_handle) // - an optionally streaming response (stream_handle) let (connection_tx, connection_rx) = tokio::sync::oneshot::channel(); let (stream_tx, stream_rx) = tokio::sync::oneshot::channel(); // detached task that will naturally close when both handles are dropped tokio::spawn(connection_monitor( engine_context.clone(), connection_rx, stream_rx, metrics, cancellation_labels, )); // Two handles, the first is armed, the second is disarmed ( ConnectionHandle::create_armed(connection_tx), ConnectionHandle::create_disabled(stream_tx), ) } #[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))] async fn connection_monitor( engine_context: Arc, connection_rx: tokio::sync::oneshot::Receiver, stream_rx: tokio::sync::oneshot::Receiver, metrics: Option>, cancellation_labels: CancellationLabels, ) { match connection_rx.await { Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => { // the client has disconnected, no need to gracefully cancel, just kill the context tracing::trace!("Connection closed unexpectedly; issuing cancellation"); if let Some(metrics) = &metrics { metrics.inc_client_disconnect(); metrics.inc_cancellation(&cancellation_labels); } engine_context.kill(); } Ok(ConnectionStatus::ClosedGracefully) => { tracing::trace!("Connection closed gracefully"); } Ok(ConnectionStatus::Disabled) => {} } match stream_rx.await { Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => { tracing::trace!("Stream closed unexpectedly; issuing cancellation"); if let Some(metrics) = &metrics { metrics.inc_client_disconnect(); metrics.inc_cancellation(&cancellation_labels); } engine_context.kill(); } Ok(ConnectionStatus::ClosedGracefully) => { tracing::trace!("Stream closed gracefully"); } Ok(ConnectionStatus::Disabled) => {} } } /// This method will consume a stream of SSE events and monitor for disconnects or context cancellation. /// /// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when /// the context is stopped. If the context is stopped, we break the stream. If the source stream ends /// naturally, we mark the request as successful and send the final `[DONE]` event. pub fn monitor_for_disconnects( stream: impl Stream>, context: Arc, mut inflight_guard: InflightGuard, mut stream_handle: ConnectionHandle, ) -> impl Stream> { stream_handle.arm(); // Default to Cancelled: if the stream is dropped unexpectedly (e.g. client // disconnect causing a broken-pipe on the SSE write), the guard will report // "cancelled" instead of "internal". The happy path overrides this via mark_ok(). inflight_guard.mark_error(ErrorType::Cancelled); async_stream::try_stream! { tokio::pin!(stream); loop { tokio::select! { event = stream.next() => { match event { Some(Ok(event)) => { yield event; } Some(Err(err)) => { // Mark error as internal since it's a streaming error inflight_guard.mark_error(ErrorType::Internal); yield Event::default().event("error").comment(err.to_string()); // Break to prevent any subsequent mark_ok() from overwriting the error break; } None => { // Stream ended normally inflight_guard.mark_ok(); stream_handle.disarm(); // todo: if we yield a dynamo sentinel event, we need to do it before the done or the // async-openai client will chomp it. yield Event::default().data("[DONE]"); break; } } } _ = context.stopped() => { tracing::trace!("Context stopped; breaking stream"); // Mark as cancelled when context is stopped (client disconnect or timeout) inflight_guard.mark_error(ErrorType::Cancelled); break; } } } } }