Unverified Commit 343a4814 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: http disconnects (#2014)

parent e330d969
......@@ -20,6 +20,7 @@
mod openai;
pub mod disconnect;
pub mod error;
pub mod health;
pub mod metrics;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! The `disconnect` module provides a mechanism for our axum http services to monitoring and responding
//! to disconnects from the client.
//!
//! There are two potential phases in any request where we need to handle the disconnect.
//!
//! For unary, request-response, there is just a single phase where the primary task that axum kicks off
//! to handle the request will be dropped if the client disconnects. In order for us to have a long running
//! task, like an LLM request, we need to spawn our long running task in a separate task and then spawn
//! a second task that will monitor for disconnects from the client. The primary task which spawned the
//! two tasks will hold an "armed" [`ConnectionHandle`] which will issue a [`ConnectionStatus::ClosedUnexpectedly`]
//! if the task is dropped before it is [`ConnectionHandle::disarm`]ed.
//!
//! For the streaming case, request in - stream out, we need a second [`ConnectionHandle`] which will be owned
//! by the stream. A streaming response is when the [`axum::response::Response]] is a [axum::response::Sse] stream.
//! This means the primary task handle will go out of scope when it returns the stream. When we create our
//! SSE stream, we capture the second [`ConnectionHandle`] and arm it. If the stream closes gracefully, the
//! second handle will be disarmed, otherwise, the stream was dropped and the [`Drop`] trait on the [`ConnectionHandle`]
//! triggers a [`ConnectionStatus::ClosedUnexpectedly`] signal.
//!
//! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] which will send a
//! [`ConnectionStatus`] enum to the primary task. The primary task will then use this to determine if it should
//! cancel the request or not.
//!
//! The [`ConnectionHandle`] is also used to signal to the client that the request has been cancelled. This is
//! done by sending a [`axum::response::sse::Event`] with the event type "error" and the data "[DONE]".
//!
use axum::response::sse::Event;
use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt};
use std::sync::Arc;
use crate::http::service::metrics::InflightGuard;
#[derive(Clone, Copy)]
pub enum ConnectionStatus {
Disabled,
ClosedUnexpectedly,
ClosedGracefully,
}
pub struct ConnectionHandle {
sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
on_drop: ConnectionStatus,
}
impl ConnectionHandle {
/// Handle which by default will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
Self {
sender: Some(sender),
on_drop: ConnectionStatus::ClosedGracefully,
}
}
/// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
Self {
sender: Some(sender),
on_drop: ConnectionStatus::ClosedUnexpectedly,
}
}
/// Handle which will not issue a signal when dropped.
pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
Self {
sender: Some(sender),
on_drop: ConnectionStatus::Disabled,
}
}
/// Handle which will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
pub fn disarm(&mut self) {
self.on_drop = ConnectionStatus::ClosedGracefully;
}
/// Handle which will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
pub fn arm(&mut self) {
self.on_drop = ConnectionStatus::ClosedUnexpectedly;
}
}
impl Drop for ConnectionHandle {
fn drop(&mut self) {
if let Some(sender) = self.sender.take() {
let _ = sender.send(self.on_drop);
}
}
}
/// Creates a pair of handles which will monitor for disconnects from the client.
///
/// The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
/// The second handle is disarmed and will issue a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
///
/// The handles are returned in the order of the first being armed and the second being disarmed.
pub async fn create_connection_monitor(
engine_context: Arc<dyn AsyncEngineContext>,
) -> (ConnectionHandle, ConnectionHandle) {
// these oneshot channels monitor possible disconnects from the client in two different scopes:
// - the local task (connection_handle)
// - an optionally streaming response (stream_handle)
let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();
// detached task that will naturally close when both handles are dropped
tokio::spawn(connection_monitor(
engine_context.clone(),
connection_rx,
stream_rx,
));
// Two handles, the first is armed, the second is disarmed
(
ConnectionHandle::create_armed(connection_tx),
ConnectionHandle::create_disabled(stream_tx),
)
}
#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
async fn connection_monitor(
engine_context: Arc<dyn AsyncEngineContext>,
connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
) {
match connection_rx.await {
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
// the client has disconnected, no need to gracefully cancel, just kill the context
tracing::trace!("Connection closed unexpectedly; issuing cancellation");
engine_context.kill();
}
Ok(ConnectionStatus::ClosedGracefully) => {
tracing::trace!("Connection closed gracefully");
}
Ok(ConnectionStatus::Disabled) => {}
}
match stream_rx.await {
Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
tracing::trace!("Stream closed unexpectedly; issuing cancellation");
engine_context.kill();
}
Ok(ConnectionStatus::ClosedGracefully) => {
tracing::trace!("Stream closed gracefully");
}
Ok(ConnectionStatus::Disabled) => {}
}
}
/// This method will consume a stream of SSE events and monitor for disconnects or context cancellation.
///
/// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when
/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
/// naturally, we mark the request as successful and send the final `[DONE]` event.
pub fn monitor_for_disconnects(
stream: impl Stream<Item = Result<Event, axum::Error>>,
context: Arc<dyn AsyncEngineContext>,
mut inflight_guard: InflightGuard,
mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Result<Event, axum::Error>> {
stream_handle.arm();
async_stream::try_stream! {
tokio::pin!(stream);
loop {
tokio::select! {
event = stream.next() => {
match event {
Some(Ok(event)) => {
yield event;
}
Some(Err(err)) => {
yield Event::default().event("error").comment(err.to_string());
}
None => {
// Stream ended normally
inflight_guard.mark_ok();
stream_handle.disarm();
// todo: if we yield a dynamo sentinel event, we need to do it before the done or the
// async-openai client will chomp it.
yield Event::default().data("[DONE]");
break;
}
}
}
_ = context.stopped() => {
tracing::trace!("Context stopped; breaking stream");
break;
}
}
}
}
}
This diff is collapsed.
......@@ -19,7 +19,7 @@
//! both publicly via the HTTP API and internally between Dynamo components.
//!
use futures::StreamExt;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
pub mod codec;
......@@ -49,12 +49,12 @@ pub trait ContentProvider {
/// Converts of a stream of [codec::Message]s into a stream of [Annotated]s.
pub fn convert_sse_stream<R>(
stream: DataStream<Result<codec::Message, codec::SseCodecError>>,
) -> DataStream<Annotated<R>>
stream: impl Stream<Item = Result<codec::Message, codec::SseCodecError>>,
) -> impl Stream<Item = Annotated<R>>
where
R: for<'de> Deserialize<'de> + Serialize,
{
let stream = stream.map(|message| match message {
stream.map(|message| match message {
Ok(message) => {
let delta = Annotated::<R>::try_from(message);
match delta {
......@@ -63,6 +63,5 @@ where
}
}
Err(e) => Annotated::from_error(e.to_string()),
});
Box::pin(stream)
})
}
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::StreamExt;
use futures::{Stream, StreamExt};
use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
......@@ -22,7 +22,6 @@ use crate::protocols::{
convert_sse_stream, Annotated,
};
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
......@@ -95,7 +94,7 @@ impl DeltaAggregator {
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
......@@ -260,7 +259,7 @@ impl NvCreateChatCompletionResponse {
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await
}
......
......@@ -16,7 +16,7 @@
use std::collections::HashMap;
use anyhow::Result;
use futures::StreamExt;
use futures::{Stream, StreamExt};
use super::NvCreateCompletionResponse;
use crate::protocols::{
......@@ -64,7 +64,7 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub async fn apply(
stream: DataStream<Annotated<NvCreateCompletionResponse>>,
stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
) -> Result<NvCreateCompletionResponse> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
......@@ -183,7 +183,7 @@ impl NvCreateCompletionResponse {
}
pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateCompletionResponse>>,
stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
) -> Result<NvCreateCompletionResponse> {
DeltaAggregator::apply(stream).await
}
......
......@@ -20,7 +20,7 @@ use crate::protocols::{
};
use dynamo_runtime::engine::DataStream;
use futures::StreamExt;
use futures::{Stream, StreamExt};
/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
/// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
......@@ -58,7 +58,7 @@ impl DeltaAggregator {
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful.
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
) -> Result<NvCreateEmbeddingResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
......@@ -133,7 +133,7 @@ impl NvCreateEmbeddingResponse {
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
) -> Result<NvCreateEmbeddingResponse, String> {
DeltaAggregator::apply(stream).await
}
......
......@@ -28,6 +28,8 @@ use dynamo_llm::http::{
},
};
use dynamo_llm::protocols::{
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -45,11 +47,31 @@ use futures::StreamExt;
use prometheus::{proto::MetricType, Registry};
use reqwest::StatusCode;
use rstest::*;
use std::sync::Arc;
use std::{io::Cursor, sync::Arc};
use tokio::time::timeout;
use tokio_util::codec::FramedRead;
struct CounterEngine {}
#[allow(deprecated)]
// Add a new long-running test engine
struct LongRunningEngine {
delay_ms: u64,
cancelled: Arc<std::sync::atomic::AtomicBool>,
}
impl LongRunningEngine {
fn new(delay_ms: u64) -> Self {
Self {
delay_ms,
cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
}
}
fn was_cancelled(&self) -> bool {
self.cancelled.load(std::sync::atomic::Ordering::Acquire)
}
}
#[async_trait]
impl
AsyncEngine<
......@@ -66,6 +88,7 @@ impl
let ctx = context.context();
// ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
#[allow(deprecated)]
let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
// let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
......@@ -88,6 +111,54 @@ impl
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for LongRunningEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (_request, context) = request.transfer(());
let ctx = context.context();
tracing::info!(
"LongRunningEngine: Starting generation with {}ms delay",
self.delay_ms
);
let cancelled_flag = self.cancelled.clone();
let delay_ms = self.delay_ms;
let ctx_clone = ctx.clone();
let stream = async_stream::stream! {
// the stream can be dropped or it can be cancelled
// either way we consider this a cancellation
cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_millis(delay_ms)) => {
// the stream went to completion
cancelled_flag.store(false, std::sync::atomic::Ordering::SeqCst);
}
_ = ctx_clone.stopped() => {
cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
yield Annotated::<NvCreateChatCompletionStreamResponse>::from_annotation("event.dynamo.test.sentinel", &"DONE".to_string()).expect("Failed to create annotated response");
};
Ok(ResponseStream::new(Box::pin(stream), ctx))
}
}
struct AlwaysFailEngine {}
#[async_trait]
......@@ -880,3 +951,311 @@ async fn test_generic_byot_client(
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[rstest]
#[tokio::test]
async fn test_client_disconnect_cancellation_unary() {
let service = HttpService::builder().port(8993).build().unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8993).await;
// Create a long-running engine (10 seconds)
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-model", long_running_engine.clone())
.unwrap();
let client = reqwest::Client::new();
let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"This will take a long time".to_string(),
),
name: None,
},
);
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("slow-model")
.messages(vec![message])
.stream(false) // Test unary response
.build()
.expect("Failed to build request");
// Start the request and cancel it after 1 second
let start_time = std::time::Instant::now();
let request_future = async {
client
.post("http://localhost:8993/v1/chat/completions")
.json(&request)
.send()
.await
};
// Use timeout to simulate client disconnect after 1 second
let result = timeout(std::time::Duration::from_millis(1000), request_future).await;
let elapsed = start_time.elapsed();
// The request should timeout (simulating client disconnect)
assert!(result.is_err(), "Request should have timed out");
// Give the service a moment to detect the disconnect and propagate cancellation
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
// Verify the engine was cancelled
assert!(
long_running_engine.was_cancelled(),
"Engine should have been cancelled due to client disconnect"
);
// Verify cancellation happened quickly (within 2 seconds, not the full 10 seconds)
assert!(
elapsed < std::time::Duration::from_secs(2),
"Cancellation should have propagated quickly, took {:?}",
elapsed
);
tracing::info!(
"✅ Client disconnect test passed! Request cancelled in {:?}, engine detected cancellation",
elapsed
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[rstest]
#[tokio::test]
async fn test_client_disconnect_cancellation_streaming() {
dynamo_runtime::logging::init();
let service = HttpService::builder().port(8994).build().unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8994).await;
// Create a long-running engine (10 seconds)
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-stream-model", long_running_engine.clone())
.unwrap();
let client = reqwest::Client::new();
let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"This will stream for a long time".to_string(),
),
name: None,
},
);
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("slow-stream-model")
.messages(vec![message])
.stream(true) // Test streaming response
.build()
.expect("Failed to build request");
// Start the request and cancel it after 1 second
let start_time = std::time::Instant::now();
let request_future = async {
let response = client
.post("http://localhost:8994/v1/chat/completions")
.json(&request)
.send()
.await
.unwrap();
// Start reading the stream, then drop it to simulate client disconnect
let mut stream = response.bytes_stream();
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
// Read one chunk then drop the stream (simulating client disconnect)
let _ = StreamExt::next(&mut stream).await;
// Stream gets dropped here when function exits
};
// Use timeout to simulate the streaming request timing out
let _result = timeout(std::time::Duration::from_millis(1500), request_future).await;
let elapsed = start_time.elapsed();
// Give the service time to detect the disconnect
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
// Verify the engine was cancelled
assert!(
long_running_engine.was_cancelled(),
"Engine should have been cancelled due to streaming client disconnect"
);
// Verify cancellation happened reasonably quickly
assert!(
elapsed < std::time::Duration::from_secs(3),
"Stream cancellation should have propagated reasonably quickly, took {:?}",
elapsed
);
tracing::info!(
"✅ Streaming client disconnect test passed! Stream cancelled in {:?}, engine detected cancellation",
elapsed
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[rstest]
#[tokio::test]
async fn test_request_id_annotation() {
// TODO(ryan): make better fixtures, this is too much to test sometime so simple
dynamo_runtime::logging::init();
let service = HttpService::builder().port(8995).build().unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8995).await;
// Add a counter engine for this test
let counter_engine = Arc::new(CounterEngine {});
manager
.add_chat_completions_model("test-model", counter_engine)
.unwrap();
// Create reqwest client directly
let client = reqwest::Client::new();
// Generate a UUID for the request ID
let request_uuid = uuid::Uuid::new_v4();
// Create the request JSON directly
let request_json = serde_json::json!({
"model": "test-model",
"messages": [
{
"role": "user",
"content": "Test request with annotation"
}
],
"stream": true,
"max_tokens": 50,
"nvext": {
"annotations": ["request_id"]
}
});
// Make the streaming request with custom header
let response = client
.post("http://localhost:8995/v1/chat/completions")
.header("x-dynamo-request-id", request_uuid.to_string())
.json(&request_json)
.send()
.await
.expect("Request should succeed");
assert!(
response.status().is_success(),
"Response should be successful"
);
// Collect the entire response body as bytes first
let body_bytes = response
.bytes()
.await
.expect("Failed to read response body");
let body_text = String::from_utf8_lossy(&body_bytes);
// Create a cursor from the text and use SseLineCodec to parse it
let cursor = Cursor::new(body_text.to_string());
let framed = FramedRead::new(cursor, SseLineCodec::new());
let annotated_stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(framed);
// Look for the annotation in the stream
let mut found_request_id_annotation = false;
let mut received_request_id = None;
// Process the annotated stream and look for the request_id annotation
let mut annotated_stream = std::pin::pin!(annotated_stream);
while let Some(annotated_response) = annotated_stream.next().await {
// Check if this is a request_id annotation
if let Some(event) = &annotated_response.event {
if event == "request_id" {
found_request_id_annotation = true;
// Extract the request ID from the annotation
if let Some(comments) = &annotated_response.comment {
if let Some(comment) = comments.first() {
// The comment contains a JSON-encoded string, so we need to parse it
if let Ok(parsed_value) = serde_json::from_str::<String>(comment) {
received_request_id = Some(parsed_value);
} else {
// Fallback: remove quotes manually if JSON parsing fails
received_request_id = Some(comment.trim_matches('"').to_string());
}
}
}
break;
}
}
}
// Verify we found the annotation
assert!(
found_request_id_annotation,
"Should have received request_id annotation in the stream"
);
// Verify the request ID matches what we sent
assert!(
received_request_id.is_some(),
"Should have received the request ID in the annotation"
);
let received_uuid_str = received_request_id.unwrap();
assert_eq!(
received_uuid_str,
request_uuid.to_string(),
"Received request ID should match the one we sent: expected {}, got {}",
request_uuid,
received_uuid_str
);
tracing::info!(
"✅ Request ID annotation test passed! Sent UUID: {}, Received: {}",
request_uuid,
received_uuid_str
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}
......@@ -75,10 +75,16 @@ impl<T: Send + Sync + 'static> Context<T> {
}
}
/// Get the id of the context
pub fn id(&self) -> &str {
self.controller.id()
}
/// Get the content of the context
pub fn content(&self) -> &T {
&self.current
}
pub fn controller(&self) -> &Controller {
&self.controller
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment