OpenDAS / dynamo
Commit 343a4814 (Unverified)
Authored Jul 18, 2025 by Ryan Olson
Committed by GitHub, Jul 18, 2025
feat: http disconnects (#2014)
Parent: e330d969
Changes: 9 changed files with 888 additions and 189 deletions (+888 -189)
lib/llm/src/http/service.rs  +1 -0
lib/llm/src/http/service/disconnect.rs  +196 -0
lib/llm/src/http/service/openai.rs  +290 -171
lib/llm/src/protocols.rs  +5 -6
lib/llm/src/protocols/openai/chat_completions/aggregator.rs  +3 -4
lib/llm/src/protocols/openai/completions/aggregator.rs  +3 -3
lib/llm/src/protocols/openai/embeddings/aggregator.rs  +3 -3
lib/llm/tests/http-service.rs  +381 -2
lib/runtime/src/pipeline/context.rs  +6 -0
lib/llm/src/http/service.rs
@@ -20,6 +20,7 @@
 mod openai;

+pub mod disconnect;
 pub mod error;
 pub mod health;
 pub mod metrics;
lib/llm/src/http/service/disconnect.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! The `disconnect` module provides a mechanism for our axum HTTP services to monitor and respond
//! to client disconnects.
//!
//! There are two potential phases in any request where we need to handle a disconnect.
//!
//! For the unary (request-response) case there is a single phase: the primary task that axum kicks off
//! to handle the request is dropped if the client disconnects. To support long-running work, such as an
//! LLM request, we spawn that work in a separate task and spawn a second task that monitors for
//! disconnects from the client. The primary task, which spawned the two tasks, holds an "armed"
//! [`ConnectionHandle`] that issues a [`ConnectionStatus::ClosedUnexpectedly`] if the primary task is
//! dropped before the handle is [`ConnectionHandle::disarm`]ed.
//!
//! For the streaming case (request in, stream out) we need a second [`ConnectionHandle`] that is owned
//! by the stream. A streaming response is one where the [`axum::response::Response`] is an
//! [`axum::response::Sse`] stream, so the primary task's handle goes out of scope when the task returns
//! the stream. When we create the SSE stream, we capture the second [`ConnectionHandle`] and arm it.
//! If the stream closes gracefully, the second handle is disarmed; otherwise the stream was dropped and
//! the [`Drop`] implementation on [`ConnectionHandle`] sends a [`ConnectionStatus::ClosedUnexpectedly`] signal.
//!
//! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] that sends a
//! [`ConnectionStatus`] to the detached monitor task spawned by [`create_connection_monitor`]. The monitor
//! uses that status to decide whether to kill the request's context.
//!
//! The handle never writes to the client itself. [`monitor_for_disconnects`] forwards errors from the
//! source stream as an [`axum::response::sse::Event`] with the event type "error" and the error text as a
//! comment, and emits a final `[DONE]` data event when the stream ends normally.
//!
use axum::response::sse::Event;
use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt};
use std::sync::Arc;

use crate::http::service::metrics::InflightGuard;

#[derive(Clone, Copy)]
pub enum ConnectionStatus {
    Disabled,
    ClosedUnexpectedly,
    ClosedGracefully,
}

pub struct ConnectionHandle {
    sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
    on_drop: ConnectionStatus,
}

impl ConnectionHandle {
    /// Creates a handle that issues a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
    pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self {
            sender: Some(sender),
            on_drop: ConnectionStatus::ClosedGracefully,
        }
    }

    /// Creates a handle that issues a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
    pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self {
            sender: Some(sender),
            on_drop: ConnectionStatus::ClosedUnexpectedly,
        }
    }

    /// Creates a handle that issues a [`ConnectionStatus::Disabled`] signal when dropped,
    /// which the monitor ignores.
    pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self {
            sender: Some(sender),
            on_drop: ConnectionStatus::Disabled,
        }
    }

    /// Disarms the handle so that it issues a [`ConnectionStatus::ClosedGracefully`] signal when dropped.
    pub fn disarm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedGracefully;
    }

    /// Arms the handle so that it issues a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
    pub fn arm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedUnexpectedly;
    }
}

impl Drop for ConnectionHandle {
    fn drop(&mut self) {
        if let Some(sender) = self.sender.take() {
            let _ = sender.send(self.on_drop);
        }
    }
}

/// Creates a pair of handles which will monitor for disconnects from the client.
///
/// The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal when dropped.
/// The second handle is disabled and will issue a [`ConnectionStatus::Disabled`] signal when dropped,
/// unless it is armed first, as [`monitor_for_disconnects`] does.
///
/// The handles are returned in that order: the armed connection handle, then the disabled stream handle.
pub async fn create_connection_monitor(
    engine_context: Arc<dyn AsyncEngineContext>,
) -> (ConnectionHandle, ConnectionHandle) {
    // these oneshot channels monitor possible disconnects from the client in two different scopes:
    // - the local task (connection_handle)
    // - an optionally streaming response (stream_handle)
    let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
    let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();

    // detached task that will naturally close when both handles are dropped
    tokio::spawn(connection_monitor(
        engine_context.clone(),
        connection_rx,
        stream_rx,
    ));

    // Two handles, the first is armed, the second is disabled
    (
        ConnectionHandle::create_armed(connection_tx),
        ConnectionHandle::create_disabled(stream_tx),
    )
}

#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
async fn connection_monitor(
    engine_context: Arc<dyn AsyncEngineContext>,
    connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
    stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
) {
    match connection_rx.await {
        Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
            // the client has disconnected, no need to gracefully cancel, just kill the context
            tracing::trace!("Connection closed unexpectedly; issuing cancellation");
            engine_context.kill();
        }
        Ok(ConnectionStatus::ClosedGracefully) => {
            tracing::trace!("Connection closed gracefully");
        }
        Ok(ConnectionStatus::Disabled) => {}
    }

    match stream_rx.await {
        Err(_) | Ok(ConnectionStatus::ClosedUnexpectedly) => {
            tracing::trace!("Stream closed unexpectedly; issuing cancellation");
            engine_context.kill();
        }
        Ok(ConnectionStatus::ClosedGracefully) => {
            tracing::trace!("Stream closed gracefully");
        }
        Ok(ConnectionStatus::Disabled) => {}
    }
}

/// This method will consume a stream of SSE events and monitor for disconnects or context cancellation.
///
/// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when
/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
/// naturally, we mark the request as successful and send the final `[DONE]` event.
pub fn monitor_for_disconnects(
    stream: impl Stream<Item = Result<Event, axum::Error>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Result<Event, axum::Error>> {
    stream_handle.arm();
    async_stream::try_stream! {
        tokio::pin!(stream);
        loop {
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(Ok(event)) => {
                            yield event;
                        }
                        Some(Err(err)) => {
                            yield Event::default().event("error").comment(err.to_string());
                        }
                        None => {
                            // Stream ended normally
                            inflight_guard.mark_ok();
                            stream_handle.disarm();

                            // todo: if we yield a dynamo sentinel event, we need to do it before the done or the
                            // async-openai client will chomp it.
                            yield Event::default().data("[DONE]");
                            break;
                        }
                    }
                }
                _ = context.stopped() => {
                    tracing::trace!("Context stopped; breaking stream");
                    break;
                }
            }
        }
    }
}
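
The drop-driven signalling in `disconnect.rs` can be tried outside of axum and tokio. The following is a minimal sketch, not the module's implementation: it assumes `std::sync::mpsc` channels and a plain thread in place of `tokio::sync::oneshot` and `tokio::spawn`, and an `AtomicBool` in place of `engine_context.kill()`. The names `Status`, `Handle`, `monitor_pair`, and `run_scenario` are hypothetical.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, RecvError, Sender};
use std::sync::Arc;
use std::thread;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
    Disabled,
    ClosedUnexpectedly,
    ClosedGracefully,
}

/// Hypothetical stand-in for `ConnectionHandle`: reports `on_drop` when dropped.
pub struct Handle {
    sender: Option<Sender<Status>>,
    on_drop: Status,
}

impl Handle {
    pub fn new(sender: Sender<Status>, on_drop: Status) -> Self {
        Self { sender: Some(sender), on_drop }
    }
    pub fn arm(&mut self) {
        self.on_drop = Status::ClosedUnexpectedly;
    }
    pub fn disarm(&mut self) {
        self.on_drop = Status::ClosedGracefully;
    }
}

impl Drop for Handle {
    fn drop(&mut self) {
        if let Some(sender) = self.sender.take() {
            // The receiver may already be gone; that is not an error here.
            let _ = sender.send(self.on_drop);
        }
    }
}

/// A lost sender or an unexpected close both mean the work should be killed.
pub fn should_kill(received: Result<Status, RecvError>) -> bool {
    matches!(received, Err(_) | Ok(Status::ClosedUnexpectedly))
}

/// Returns (armed connection handle, disabled stream handle, monitor thread).
/// The monitor waits on the connection first and the stream second.
pub fn monitor_pair(killed: Arc<AtomicBool>) -> (Handle, Handle, thread::JoinHandle<()>) {
    let (conn_tx, conn_rx) = channel();
    let (stream_tx, stream_rx) = channel();
    let monitor = thread::spawn(move || {
        for rx in [conn_rx, stream_rx] {
            if should_kill(rx.recv()) {
                killed.store(true, Ordering::SeqCst);
            }
        }
    });
    (
        Handle::new(conn_tx, Status::ClosedUnexpectedly),
        Handle::new(stream_tx, Status::Disabled),
        monitor,
    )
}

/// Drives one request lifecycle and reports whether the monitor asked for a kill.
pub fn run_scenario(disarm_connection: bool, arm_stream: bool, disarm_stream: bool) -> bool {
    let killed = Arc::new(AtomicBool::new(false));
    let (mut conn, mut stream, monitor) = monitor_pair(killed.clone());
    if disarm_connection {
        conn.disarm();
    }
    drop(conn);
    if arm_stream {
        stream.arm();
    }
    if disarm_stream {
        stream.disarm();
    }
    drop(stream);
    monitor.join().expect("monitor thread panicked");
    killed.load(Ordering::SeqCst)
}
```

In this sketch, `run_scenario(false, false, false)` models a unary request whose task is dropped while its handle is still armed and returns `true`, while `run_scenario(true, true, true)` models a stream that ends normally and returns `false`.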
lib/llm/src/http/service/openai.rs
This diff is collapsed.
lib/llm/src/protocols.rs
@@ -19,7 +19,7 @@
 //! both publicly via the HTTP API and internally between Dynamo components.
 //!

-use futures::StreamExt;
+use futures::{Stream, StreamExt};
 use serde::{Deserialize, Serialize};

 pub mod codec;
@@ -49,12 +49,12 @@ pub trait ContentProvider {

 /// Converts a stream of [codec::Message]s into a stream of [Annotated]s.
 pub fn convert_sse_stream<R>(
-    stream: DataStream<Result<codec::Message, codec::SseCodecError>>,
-) -> DataStream<Annotated<R>>
+    stream: impl Stream<Item = Result<codec::Message, codec::SseCodecError>>,
+) -> impl Stream<Item = Annotated<R>>
 where
     R: for<'de> Deserialize<'de> + Serialize,
 {
-    let stream = stream.map(|message| match message {
+    stream.map(|message| match message {
         Ok(message) => {
             let delta = Annotated::<R>::try_from(message);
             match delta {
@@ -63,6 +63,5 @@ where
             }
         }
         Err(e) => Annotated::from_error(e.to_string()),
-    });
-    Box::pin(stream)
+    })
 }
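
The hunk above swaps a boxed `DataStream` parameter and return type for `impl Stream`. As a rough analogy only, the same trade-off can be shown with `std::iter::Iterator`, since `Stream` lives in the `futures` crate and is not used here. The function names below are hypothetical and do not appear in the commit.

```rust
/// Boxed form: every caller must allocate and erase the concrete iterator type.
pub fn sum_boxed(items: Box<dyn Iterator<Item = u32>>) -> u32 {
    items.sum()
}

/// Generic form: accepts any iterator, including an already boxed one.
pub fn sum_generic(items: impl Iterator<Item = u32>) -> u32 {
    items.sum()
}

/// Returning `impl Iterator` avoids the `Box` on the way out as well.
pub fn double_all(items: impl Iterator<Item = u32>) -> impl Iterator<Item = u32> {
    items.map(|x| x * 2)
}

/// A boxed iterator still satisfies the generic signature, so existing
/// callers that hold a box do not need to change.
pub fn sum_from_box() -> u32 {
    let boxed: Box<dyn Iterator<Item = u32>> = Box::new(vec![1u32, 2, 3].into_iter());
    sum_generic(boxed)
}
```

Under this analogy, callers that already hold a boxed value keep working after the signature change, and callers with a concrete type skip the allocation.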
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use futures::StreamExt;
+use futures::{Stream, StreamExt};
 use std::collections::HashMap;

 use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
@@ -22,7 +22,6 @@ use crate::protocols::{
     convert_sse_stream, Annotated,
 };

-/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
 use dynamo_runtime::engine::DataStream;

 /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
@@ -95,7 +94,7 @@ impl DeltaAggregator {
     /// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
     /// * `Err(String)` if an error occurs during processing.
     pub async fn apply(
-        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
+        stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -260,7 +259,7 @@ impl NvCreateChatCompletionResponse {
     /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
     /// * `Err(String)` if an error occurs.
     pub async fn from_annotated_stream(
-        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
+        stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         DeltaAggregator::apply(stream).await
     }
lib/llm/src/protocols/openai/completions/aggregator.rs
@@ -16,7 +16,7 @@
 use std::collections::HashMap;

 use anyhow::Result;
-use futures::StreamExt;
+use futures::{Stream, StreamExt};

 use super::NvCreateCompletionResponse;
 use crate::protocols::{
@@ -64,7 +64,7 @@ impl DeltaAggregator {
     /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
     pub async fn apply(
-        stream: DataStream<Annotated<NvCreateCompletionResponse>>,
+        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
     ) -> Result<NvCreateCompletionResponse> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -183,7 +183,7 @@ impl NvCreateCompletionResponse {
     }

     pub async fn from_annotated_stream(
-        stream: DataStream<Annotated<NvCreateCompletionResponse>>,
+        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
     ) -> Result<NvCreateCompletionResponse> {
         DeltaAggregator::apply(stream).await
     }
lib/llm/src/protocols/openai/embeddings/aggregator.rs
@@ -20,7 +20,7 @@ use crate::protocols::{
 };

 use dynamo_runtime::engine::DataStream;
-use futures::StreamExt;
+use futures::{Stream, StreamExt};

 /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
 /// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
@@ -58,7 +58,7 @@ impl DeltaAggregator {
     /// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful.
     /// * `Err(String)` if an error occurs during processing.
     pub async fn apply(
-        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
+        stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
     ) -> Result<NvCreateEmbeddingResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -133,7 +133,7 @@ impl NvCreateEmbeddingResponse {
     /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
     /// * `Err(String)` if an error occurs.
     pub async fn from_annotated_stream(
-        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
+        stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
     ) -> Result<NvCreateEmbeddingResponse, String> {
         DeltaAggregator::apply(stream).await
     }
lib/llm/tests/http-service.rs
@@ -28,6 +28,8 @@ use dynamo_llm::http::{
     },
 };
 use dynamo_llm::protocols::{
+    codec::SseLineCodec,
+    convert_sse_stream,
     openai::{
         chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
         completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
@@ -45,11 +47,31 @@ use futures::StreamExt;
 use prometheus::{proto::MetricType, Registry};
 use reqwest::StatusCode;
 use rstest::*;
-use std::sync::Arc;
+use std::{io::Cursor, sync::Arc};
+use tokio::time::timeout;
+use tokio_util::codec::FramedRead;

 struct CounterEngine {}

-#[allow(deprecated)]
+// Add a new long-running test engine
+struct LongRunningEngine {
+    delay_ms: u64,
+    cancelled: Arc<std::sync::atomic::AtomicBool>,
+}
+
+impl LongRunningEngine {
+    fn new(delay_ms: u64) -> Self {
+        Self {
+            delay_ms,
+            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
+        }
+    }
+
+    fn was_cancelled(&self) -> bool {
+        self.cancelled.load(std::sync::atomic::Ordering::Acquire)
+    }
+}
+
 #[async_trait]
 impl
     AsyncEngine<
@@ -66,6 +88,7 @@ impl
         let ctx = context.context();

         // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
+        #[allow(deprecated)]
         let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;

         // let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
@@ -88,6 +111,54 @@ impl
     }
 }

+#[async_trait]
+impl
+    AsyncEngine<
+        SingleIn<NvCreateChatCompletionRequest>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
+        Error,
+    > for LongRunningEngine
+{
+    async fn generate(
+        &self,
+        request: SingleIn<NvCreateChatCompletionRequest>,
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
+        let (_request, context) = request.transfer(());
+        let ctx = context.context();
+
+        tracing::info!(
+            "LongRunningEngine: Starting generation with {}ms delay",
+            self.delay_ms
+        );
+
+        let cancelled_flag = self.cancelled.clone();
+        let delay_ms = self.delay_ms;
+
+        let ctx_clone = ctx.clone();
+        let stream = async_stream::stream! {
+            // the stream can be dropped or it can be cancelled
+            // either way we consider this a cancellation
+            cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
+
+            tokio::select! {
+                _ = tokio::time::sleep(std::time::Duration::from_millis(delay_ms)) => {
+                    // the stream went to completion
+                    cancelled_flag.store(false, std::sync::atomic::Ordering::SeqCst);
+                }
+                _ = ctx_clone.stopped() => {
+                    cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
+                }
+            }
+
+            yield Annotated::<NvCreateChatCompletionStreamResponse>::from_annotation("event.dynamo.test.sentinel", &"DONE".to_string()).expect("Failed to create annotated response");
+        };
+
+        Ok(ResponseStream::new(Box::pin(stream), ctx))
+    }
+}
+
 struct AlwaysFailEngine {}

 #[async_trait]
@@ -880,3 +951,311 @@ async fn test_generic_byot_client(
     cancel_token.cancel();
     task.await.unwrap().unwrap();
 }
+
+#[rstest]
+#[tokio::test]
+async fn test_client_disconnect_cancellation_unary() {
+    let service = HttpService::builder().port(8993).build().unwrap();
+    let state = service.state_clone();
+    let manager = state.manager();
+
+    let token = CancellationToken::new();
+    let cancel_token = token.clone();
+
+    // Start the service
+    let task = tokio::spawn(async move { service.run(token).await });
+
+    // Wait for service to be ready
+    wait_for_service_ready(8993).await;
+
+    // Create a long-running engine (10 seconds)
+    let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
+    manager
+        .add_chat_completions_model("slow-model", long_running_engine.clone())
+        .unwrap();
+
+    let client = reqwest::Client::new();
+
+    let message = async_openai::types::ChatCompletionRequestMessage::User(
+        async_openai::types::ChatCompletionRequestUserMessage {
+            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                "This will take a long time".to_string(),
+            ),
+            name: None,
+        },
+    );
+
+    let request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("slow-model")
+        .messages(vec![message])
+        .stream(false) // Test unary response
+        .build()
+        .expect("Failed to build request");
+
+    // Start the request and cancel it after 1 second
+    let start_time = std::time::Instant::now();
+
+    let request_future = async {
+        client
+            .post("http://localhost:8993/v1/chat/completions")
+            .json(&request)
+            .send()
+            .await
+    };
+
+    // Use timeout to simulate client disconnect after 1 second
+    let result = timeout(std::time::Duration::from_millis(1000), request_future).await;
+    let elapsed = start_time.elapsed();
+
+    // The request should timeout (simulating client disconnect)
+    assert!(result.is_err(), "Request should have timed out");
+
+    // Give the service a moment to detect the disconnect and propagate cancellation
+    tokio::time::sleep(std::time::Duration::from_millis(500)).await;
+
+    // Verify the engine was cancelled
+    assert!(
+        long_running_engine.was_cancelled(),
+        "Engine should have been cancelled due to client disconnect"
+    );
+
+    // Verify cancellation happened quickly (within 2 seconds, not the full 10 seconds)
+    assert!(
+        elapsed < std::time::Duration::from_secs(2),
+        "Cancellation should have propagated quickly, took {:?}",
+        elapsed
+    );
+
+    tracing::info!(
+        "✅ Client disconnect test passed! Request cancelled in {:?}, engine detected cancellation",
+        elapsed
+    );
+
+    cancel_token.cancel();
+    task.await.unwrap().unwrap();
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_client_disconnect_cancellation_streaming() {
+    dynamo_runtime::logging::init();
+
+    let service = HttpService::builder().port(8994).build().unwrap();
+    let state = service.state_clone();
+    let manager = state.manager();
+
+    let token = CancellationToken::new();
+    let cancel_token = token.clone();
+
+    // Start the service
+    let task = tokio::spawn(async move { service.run(token).await });
+
+    // Wait for service to be ready
+    wait_for_service_ready(8994).await;
+
+    // Create a long-running engine (10 seconds)
+    let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
+    manager
+        .add_chat_completions_model("slow-stream-model", long_running_engine.clone())
+        .unwrap();
+
+    let client = reqwest::Client::new();
+
+    let message = async_openai::types::ChatCompletionRequestMessage::User(
+        async_openai::types::ChatCompletionRequestUserMessage {
+            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                "This will stream for a long time".to_string(),
+            ),
+            name: None,
+        },
+    );
+
+    let request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("slow-stream-model")
+        .messages(vec![message])
+        .stream(true) // Test streaming response
+        .build()
+        .expect("Failed to build request");
+
+    // Start the request and cancel it after 1 second
+    let start_time = std::time::Instant::now();
+
+    let request_future = async {
+        let response = client
+            .post("http://localhost:8994/v1/chat/completions")
+            .json(&request)
+            .send()
+            .await
+            .unwrap();
+
+        // Start reading the stream, then drop it to simulate client disconnect
+        let mut stream = response.bytes_stream();
+        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
+
+        // Read one chunk then drop the stream (simulating client disconnect)
+        let _ = StreamExt::next(&mut stream).await;
+        // Stream gets dropped here when function exits
+    };
+
+    // Use timeout to simulate the streaming request timing out
+    let _result = timeout(std::time::Duration::from_millis(1500), request_future).await;
+    let elapsed = start_time.elapsed();
+
+    // Give the service time to detect the disconnect
+    tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
+
+    // Verify the engine was cancelled
+    assert!(
+        long_running_engine.was_cancelled(),
+        "Engine should have been cancelled due to streaming client disconnect"
+    );
+
+    // Verify cancellation happened reasonably quickly
+    assert!(
+        elapsed < std::time::Duration::from_secs(3),
+        "Stream cancellation should have propagated reasonably quickly, took {:?}",
+        elapsed
+    );
+
+    tracing::info!(
+        "✅ Streaming client disconnect test passed! Stream cancelled in {:?}, engine detected cancellation",
+        elapsed
+    );
+
+    cancel_token.cancel();
+    task.await.unwrap().unwrap();
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_request_id_annotation() {
+    // TODO(ryan): make better fixtures, this is too much to test something so simple
+    dynamo_runtime::logging::init();
+
+    let service = HttpService::builder().port(8995).build().unwrap();
+    let state = service.state_clone();
+    let manager = state.manager();
+
+    let token = CancellationToken::new();
+    let cancel_token = token.clone();
+
+    // Start the service
+    let task = tokio::spawn(async move { service.run(token).await });
+
+    // Wait for service to be ready
+    wait_for_service_ready(8995).await;
+
+    // Add a counter engine for this test
+    let counter_engine = Arc::new(CounterEngine {});
+    manager
+        .add_chat_completions_model("test-model", counter_engine)
+        .unwrap();
+
+    // Create reqwest client directly
+    let client = reqwest::Client::new();
+
+    // Generate a UUID for the request ID
+    let request_uuid = uuid::Uuid::new_v4();
+
+    // Create the request JSON directly
+    let request_json = serde_json::json!({
+        "model": "test-model",
+        "messages": [
+            {
+                "role": "user",
+                "content": "Test request with annotation"
+            }
+        ],
+        "stream": true,
+        "max_tokens": 50,
+        "nvext": {
+            "annotations": ["request_id"]
+        }
+    });
+
+    // Make the streaming request with custom header
+    let response = client
+        .post("http://localhost:8995/v1/chat/completions")
+        .header("x-dynamo-request-id", request_uuid.to_string())
+        .json(&request_json)
+        .send()
+        .await
+        .expect("Request should succeed");
+
+    assert!(
+        response.status().is_success(),
+        "Response should be successful"
+    );
+
+    // Collect the entire response body as bytes first
+    let body_bytes = response
+        .bytes()
+        .await
+        .expect("Failed to read response body");
+    let body_text = String::from_utf8_lossy(&body_bytes);
+
+    // Create a cursor from the text and use SseLineCodec to parse it
+    let cursor = Cursor::new(body_text.to_string());
+    let framed = FramedRead::new(cursor, SseLineCodec::new());
+    let annotated_stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(framed);
+
+    // Look for the annotation in the stream
+    let mut found_request_id_annotation = false;
+    let mut received_request_id = None;
+
+    // Process the annotated stream and look for the request_id annotation
+    let mut annotated_stream = std::pin::pin!(annotated_stream);
+    while let Some(annotated_response) = annotated_stream.next().await {
+        // Check if this is a request_id annotation
+        if let Some(event) = &annotated_response.event {
+            if event == "request_id" {
+                found_request_id_annotation = true;
+
+                // Extract the request ID from the annotation
+                if let Some(comments) = &annotated_response.comment {
+                    if let Some(comment) = comments.first() {
+                        // The comment contains a JSON-encoded string, so we need to parse it
+                        if let Ok(parsed_value) = serde_json::from_str::<String>(comment) {
+                            received_request_id = Some(parsed_value);
+                        } else {
+                            // Fallback: remove quotes manually if JSON parsing fails
+                            received_request_id = Some(comment.trim_matches('"').to_string());
+                        }
+                    }
+                }
+                break;
+            }
+        }
+    }
+
+    // Verify we found the annotation
+    assert!(
+        found_request_id_annotation,
+        "Should have received request_id annotation in the stream"
+    );
+
+    // Verify the request ID matches what we sent
+    assert!(
+        received_request_id.is_some(),
+        "Should have received the request ID in the annotation"
+    );
+
+    let received_uuid_str = received_request_id.unwrap();
+    assert_eq!(
+        received_uuid_str,
+        request_uuid.to_string(),
+        "Received request ID should match the one we sent: expected {}, got {}",
+        request_uuid,
+        received_uuid_str
+    );
+
+    tracing::info!(
+        "✅ Request ID annotation test passed! Sent UUID: {}, Received: {}",
+        request_uuid,
+        received_uuid_str
+    );
+
+    cancel_token.cancel();
+    task.await.unwrap().unwrap();
+}
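
`LongRunningEngine` records cancellation pessimistically: the flag is set to `true` as soon as the stream starts and is cleared only when the work runs to completion, so a stream that is dropped part-way through still reads as cancelled. A minimal synchronous sketch of that idea follows. It assumes an early `return` as a stand-in for the stream being dropped, and the names `run_work` and `cancelled_after` are hypothetical.

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// Runs `total_steps` units of work. If `abandon_after` is `Some(n)`, the work
/// is abandoned when it reaches step `n`, modelling a dropped stream.
pub fn run_work(cancelled: &AtomicBool, total_steps: u32, abandon_after: Option<u32>) {
    // Assume the worst first: any exit path other than completion counts as cancelled.
    cancelled.store(true, Ordering::SeqCst);

    for step in 0..total_steps {
        if abandon_after == Some(step) {
            // Abandoned mid-way: the flag stays `true`.
            return;
        }
    }

    // Only a full run clears the flag.
    cancelled.store(false, Ordering::SeqCst);
}

/// Convenience wrapper that reports the flag after one run.
pub fn cancelled_after(total_steps: u32, abandon_after: Option<u32>) -> bool {
    let cancelled = AtomicBool::new(false);
    run_work(&cancelled, total_steps, abandon_after);
    cancelled.load(Ordering::SeqCst)
}
```

The benefit of setting the flag up front is that no cleanup code has to run on the abandoned path for the test to observe the cancellation.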
lib/runtime/src/pipeline/context.rs
@@ -75,10 +75,16 @@ impl<T: Send + Sync + 'static> Context<T> {
         }
     }

+    /// Get the id of the context
+    pub fn id(&self) -> &str {
+        self.controller.id()
+    }
+
+    /// Get the content of the context
     pub fn content(&self) -> &T {
         &self.current
     }

     pub fn controller(&self) -> &Controller {
         &self.controller
     }