"docs/vscode:/vscode.git/clone" did not exist on "38ef94888afc0c2bccc2f18422d2b525d7649ac3"
Unverified Commit 343a4814 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: http disconnects (#2014)

parent e330d969
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
mod openai; mod openai;
pub mod disconnect;
pub mod error; pub mod error;
pub mod health; pub mod health;
pub mod metrics; pub mod metrics;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! The `disconnect` module provides a mechanism for our axum http services to monitor and respond
//! to disconnects from the client.
//!
//! There are two potential phases in any request where we need to handle the disconnect.
//!
//! For unary, request-response, there is just a single phase where the primary task that axum kicks off
//! to handle the request will be dropped if the client disconnects. In order for us to have a long running
//! task, like an LLM request, we need to spawn our long running task in a separate task and then spawn
//! a second task that will monitor for disconnects from the client. The primary task which spawned the
//! two tasks will hold an "armed" [`ConnectionHandle`] which will issue a [`ConnectionStatus::ClosedUnexpectedly`]
//! if the task is dropped before it is [`ConnectionHandle::disarm`]ed.
//!
//! For the streaming case, request in - stream out, we need a second [`ConnectionHandle`] which will be owned
//! by the stream. A streaming response is when the [`axum::response::Response`] is an [`axum::response::Sse`] stream.
//! This means the primary task handle will go out of scope when it returns the stream. When we create our
//! SSE stream, we capture the second [`ConnectionHandle`] and arm it. If the stream closes gracefully, the
//! second handle will be disarmed, otherwise, the stream was dropped and the [`Drop`] trait on the [`ConnectionHandle`]
//! triggers a [`ConnectionStatus::ClosedUnexpectedly`] signal.
//!
//! The [`ConnectionHandle`] is a simple wrapper around a [`tokio::sync::oneshot::Sender`] which will send a
//! [`ConnectionStatus`] enum to a detached monitor task when the handle is dropped. The monitor task will then
//! use this status to determine if it should cancel the request or not.
//!
//! The stream returned by [`monitor_for_disconnects`] also signals the outcome to the client. Errors from the
//! source stream are sent as an [`axum::response::sse::Event`] with the event type "error", and a stream that
//! ends normally is terminated with a final event carrying the data "[DONE]".
//!
use axum::response::sse::Event;
use dynamo_runtime::engine::AsyncEngineContext;
use futures::{Stream, StreamExt};
use std::sync::Arc;
use crate::http::service::metrics::InflightGuard;
/// Status reported by a [`ConnectionHandle`] when it is dropped.
///
/// The monitor task spawned by [`create_connection_monitor`] receives this value and decides
/// whether the engine context must be killed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Monitoring is switched off for this handle; the monitor takes no action.
    Disabled,
    /// The handle was dropped while armed; the monitor kills the engine context.
    ClosedUnexpectedly,
    /// The handle was dropped after being disarmed; the monitor only logs the event.
    ClosedGracefully,
}
/// Drop guard that reports a [`ConnectionStatus`] over a oneshot channel when it goes out of scope.
///
/// The status that will be reported is chosen at construction time and can be changed afterwards
/// with [`ConnectionHandle::arm`] and [`ConnectionHandle::disarm`].
pub struct ConnectionHandle {
/// Sending half of the oneshot channel; wrapped in an `Option` so that `Drop` can take
/// ownership of it in order to send.
sender: Option<tokio::sync::oneshot::Sender<ConnectionStatus>>,
/// Status that is sent when the handle is dropped.
on_drop: ConnectionStatus,
}
impl ConnectionHandle {
    /// Shared constructor: wraps `sender` and records the status to report on drop.
    fn with_on_drop(
        sender: tokio::sync::oneshot::Sender<ConnectionStatus>,
        on_drop: ConnectionStatus,
    ) -> Self {
        Self {
            sender: Some(sender),
            on_drop,
        }
    }

    /// Builds a handle that reports [`ConnectionStatus::ClosedGracefully`] when dropped,
    /// unless it is armed first.
    pub fn create_disarmed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_on_drop(sender, ConnectionStatus::ClosedGracefully)
    }

    /// Builds a handle that reports [`ConnectionStatus::ClosedUnexpectedly`] when dropped,
    /// unless it is disarmed first.
    pub fn create_armed(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_on_drop(sender, ConnectionStatus::ClosedUnexpectedly)
    }

    /// Builds a handle that reports [`ConnectionStatus::Disabled`] when dropped, unless it is
    /// armed or disarmed first. The connection monitor takes no action on that status.
    pub fn create_disabled(sender: tokio::sync::oneshot::Sender<ConnectionStatus>) -> Self {
        Self::with_on_drop(sender, ConnectionStatus::Disabled)
    }

    /// Switches the handle so that dropping it reports [`ConnectionStatus::ClosedGracefully`].
    pub fn disarm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedGracefully;
    }

    /// Switches the handle so that dropping it reports [`ConnectionStatus::ClosedUnexpectedly`].
    pub fn arm(&mut self) {
        self.on_drop = ConnectionStatus::ClosedUnexpectedly;
    }
}
impl Drop for ConnectionHandle {
    /// Reports the currently configured status to the receiving side of the channel.
    fn drop(&mut self) {
        let Some(tx) = self.sender.take() else {
            return;
        };
        // The receiver may already be gone; a failed send is deliberately ignored.
        let _ = tx.send(self.on_drop);
    }
}
/// Creates a pair of handles which will monitor for disconnects from the client.
///
/// A detached monitor task is spawned which owns `engine_context` and kills it if either handle
/// reports [`ConnectionStatus::ClosedUnexpectedly`] (or its sender disappears without reporting).
///
/// The handles are returned as `(connection_handle, stream_handle)`:
/// - The first handle is armed and will issue a [`ConnectionStatus::ClosedUnexpectedly`] signal
///   when dropped, unless it is disarmed first.
/// - The second handle is disabled and will issue a [`ConnectionStatus::Disabled`] signal when
///   dropped, which the monitor ignores. It only takes part in monitoring once it has been armed
///   (or disarmed) by its owner, e.g. by a streaming response.
///
/// # Panics
///
/// Uses [`tokio::spawn`], so it must be called from within a Tokio runtime.
pub async fn create_connection_monitor(
    engine_context: Arc<dyn AsyncEngineContext>,
) -> (ConnectionHandle, ConnectionHandle) {
    // these oneshot channels monitor possible disconnects from the client in two different scopes:
    // - the local task (connection_handle)
    // - an optionally streaming response (stream_handle)
    let (connection_tx, connection_rx) = tokio::sync::oneshot::channel();
    let (stream_tx, stream_rx) = tokio::sync::oneshot::channel();

    // detached task that will naturally close when both handles are dropped; the context is moved
    // into the task since this function has no further use for it
    tokio::spawn(connection_monitor(
        engine_context,
        connection_rx,
        stream_rx,
    ));

    // Two handles, the first is armed, the second is disabled
    (
        ConnectionHandle::create_armed(connection_tx),
        ConnectionHandle::create_disabled(stream_tx),
    )
}
/// Waits for the connection handle and then the stream handle to report their status, killing
/// the engine context whenever one of them closed unexpectedly.
///
/// A receive error (the sender vanished without reporting) is treated the same as
/// [`ConnectionStatus::ClosedUnexpectedly`].
#[tracing::instrument(level = "trace", skip_all, fields(request_id = %engine_context.id()))]
async fn connection_monitor(
    engine_context: Arc<dyn AsyncEngineContext>,
    connection_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
    stream_rx: tokio::sync::oneshot::Receiver<ConnectionStatus>,
) {
    // Phase one: the request handler task.
    let connection_status = connection_rx
        .await
        .unwrap_or(ConnectionStatus::ClosedUnexpectedly);
    match connection_status {
        ConnectionStatus::ClosedUnexpectedly => {
            // the client is gone, so there is nobody to cancel gracefully for; kill the context
            tracing::trace!("Connection closed unexpectedly; issuing cancellation");
            engine_context.kill();
        }
        ConnectionStatus::ClosedGracefully => {
            tracing::trace!("Connection closed gracefully");
        }
        ConnectionStatus::Disabled => {}
    }

    // Phase two: the (optional) streaming response.
    let stream_status = stream_rx
        .await
        .unwrap_or(ConnectionStatus::ClosedUnexpectedly);
    match stream_status {
        ConnectionStatus::ClosedUnexpectedly => {
            tracing::trace!("Stream closed unexpectedly; issuing cancellation");
            engine_context.kill();
        }
        ConnectionStatus::ClosedGracefully => {
            tracing::trace!("Stream closed gracefully");
        }
        ConnectionStatus::Disabled => {}
    }
}
/// This method will consume a stream of SSE events and monitor for disconnects or context cancellation.
///
/// Uses `tokio::select!` to choose between receiving events from the source stream or detecting when
/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
/// naturally, we mark the request as successful and send the final `[DONE]` event.
///
/// - `stream`: the source SSE event stream to forward.
/// - `context`: the engine context whose `stopped()` future ends the forwarding loop.
/// - `inflight_guard`: marked ok only when the source stream ends normally.
/// - `stream_handle`: armed on entry and disarmed only when the source stream ends normally, so
///   dropping the returned stream in any other state reports [`ConnectionStatus::ClosedUnexpectedly`].
pub fn monitor_for_disconnects(
stream: impl Stream<Item = Result<Event, axum::Error>>,
context: Arc<dyn AsyncEngineContext>,
mut inflight_guard: InflightGuard,
mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Result<Event, axum::Error>> {
// From here on, dropping the handle without disarming it reports ClosedUnexpectedly.
stream_handle.arm();
async_stream::try_stream! {
tokio::pin!(stream);
loop {
tokio::select! {
event = stream.next() => {
match event {
Some(Ok(event)) => {
yield event;
}
Some(Err(err)) => {
// Surface the error to the client as an SSE "error" event; the loop keeps polling the source.
yield Event::default().event("error").comment(err.to_string());
}
None => {
// Stream ended normally
inflight_guard.mark_ok();
stream_handle.disarm();
// todo: if we yield a dynamo sentinel event, we need to do it before the done or the
// async-openai client will chomp it.
yield Event::default().data("[DONE]");
break;
}
}
}
_ = context.stopped() => {
// The handle stays armed and the guard is not marked ok on this path.
tracing::trace!("Context stopped; breaking stream");
break;
}
}
}
}
}
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
use std::{ use std::{
collections::HashSet, collections::HashSet,
pin::Pin,
sync::Arc, sync::Arc,
time::{SystemTime, UNIX_EPOCH}, time::{SystemTime, UNIX_EPOCH},
}; };
use axum::{ use axum::{
extract::State, extract::State,
http::StatusCode, http::{HeaderMap, StatusCode},
response::{ response::{
sse::{Event, KeepAlive, Sse}, sse::{Event, KeepAlive, Sse},
IntoResponse, Response, IntoResponse, Response,
...@@ -18,14 +17,17 @@ use axum::{ ...@@ -18,14 +17,17 @@ use axum::{
routing::{get, post}, routing::{get, post},
Json, Router, Json, Router,
}; };
use dynamo_runtime::pipeline::{AsyncEngineContext, Context}; use dynamo_runtime::{
use futures::{Stream, StreamExt}; pipeline::{AsyncEngineContextProvider, Context},
protocols::annotated::AnnotationsProvider,
};
use futures::{stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio_stream::wrappers::ReceiverStream;
use super::{ use super::{
disconnect::{create_connection_monitor, monitor_for_disconnects, ConnectionHandle},
error::HttpError, error::HttpError,
metrics::{Endpoint, InflightGuard, ResponseMetricCollector}, metrics::{Endpoint, ResponseMetricCollector},
service_v2, RouteDoc, service_v2, RouteDoc,
}; };
use crate::preprocessor::LLMMetricAnnotation; use crate::preprocessor::LLMMetricAnnotation;
...@@ -38,17 +40,24 @@ use crate::protocols::openai::{ ...@@ -38,17 +40,24 @@ use crate::protocols::openai::{
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::Annotated; use crate::types::Annotated;
pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id";
/// Dynamo Annotation for the request ID
pub const ANNOTATION_REQUEST_ID: &str = "request_id";
pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub(crate) struct ErrorResponse { pub(crate) struct ErrorMessage {
error: String, error: String,
} }
impl ErrorResponse { impl ErrorMessage {
/// Not Found Error /// Not Found Error
pub fn model_not_found() -> (StatusCode, Json<ErrorResponse>) { pub fn model_not_found() -> ErrorResponse {
( (
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(ErrorResponse { Json(ErrorMessage {
error: "Model not found".to_string(), error: "Model not found".to_string(),
}), }),
) )
...@@ -56,10 +65,10 @@ impl ErrorResponse { ...@@ -56,10 +65,10 @@ impl ErrorResponse {
/// Service Unavailable /// Service Unavailable
/// This is returned when the service is live, but not ready. /// This is returned when the service is live, but not ready.
pub fn _service_unavailable() -> (StatusCode, Json<ErrorResponse>) { pub fn _service_unavailable() -> ErrorResponse {
( (
StatusCode::SERVICE_UNAVAILABLE, StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse { Json(ErrorMessage {
error: "Service is not ready".to_string(), error: "Service is not ready".to_string(),
}), }),
) )
...@@ -69,11 +78,11 @@ impl ErrorResponse { ...@@ -69,11 +78,11 @@ impl ErrorResponse {
/// Return this error when the service encounters an internal error. /// Return this error when the service encounters an internal error.
/// We should return a generic message to the client instead of the real error. /// We should return a generic message to the client instead of the real error.
/// Internal Services errors are the result of misconfiguration or bugs in the service. /// Internal Services errors are the result of misconfiguration or bugs in the service.
pub fn internal_server_error(msg: &str) -> (StatusCode, Json<ErrorResponse>) { pub fn internal_server_error(msg: &str) -> ErrorResponse {
tracing::error!("Internal server error: {msg}"); tracing::error!("Internal server error: {msg}");
( (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse { Json(ErrorMessage {
error: msg.to_string(), error: msg.to_string(),
}), }),
) )
...@@ -82,11 +91,11 @@ impl ErrorResponse { ...@@ -82,11 +91,11 @@ impl ErrorResponse {
/// Not Implemented Error /// Not Implemented Error
/// Return this error when the client requests a feature that is not yet implemented. /// Return this error when the client requests a feature that is not yet implemented.
/// This should be used for features that are planned but not available. /// This should be used for features that are planned but not available.
pub fn not_implemented_error(msg: &str) -> (StatusCode, Json<ErrorResponse>) { pub fn not_implemented_error(msg: &str) -> ErrorResponse {
tracing::error!("Not Implemented error: {msg}"); tracing::error!("Not Implemented error: {msg}");
( (
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
Json(ErrorResponse { Json(ErrorMessage {
error: msg.to_string(), error: msg.to_string(),
}), }),
) )
...@@ -94,31 +103,56 @@ impl ErrorResponse { ...@@ -94,31 +103,56 @@ impl ErrorResponse {
/// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return /// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return
/// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`]. /// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`].
/// If successful, it will return the [`HttpError`] as an [`ErrorResponse::internal_server_error`] /// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`]
/// with the details of the error. /// with the details of the error.
pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> (StatusCode, Json<ErrorResponse>) { pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse {
match err.downcast::<HttpError>() { match err.downcast::<HttpError>() {
Ok(http_error) => ErrorResponse::from_http_error(http_error), Ok(http_error) => ErrorMessage::from_http_error(http_error),
Err(err) => ErrorResponse::internal_server_error(&format!("{alt_msg}: {err}")), Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")),
} }
} }
/// Implementers should only be able to throw 400-499 errors. /// Implementers should only be able to throw 400-499 errors.
pub fn from_http_error(err: HttpError) -> (StatusCode, Json<ErrorResponse>) { pub fn from_http_error(err: HttpError) -> ErrorResponse {
if err.code < 400 || err.code >= 500 { if err.code < 400 || err.code >= 500 {
return ErrorResponse::internal_server_error(&err.message); return ErrorMessage::internal_server_error(&err.message);
} }
match StatusCode::from_u16(err.code) { match StatusCode::from_u16(err.code) {
Ok(code) => (code, Json(ErrorResponse { error: err.message })), Ok(code) => (code, Json(ErrorMessage { error: err.message })),
Err(_) => ErrorResponse::internal_server_error(&err.message), Err(_) => ErrorMessage::internal_server_error(&err.message),
} }
} }
} }
impl From<HttpError> for ErrorResponse { impl From<HttpError> for ErrorMessage {
fn from(err: HttpError) -> Self { fn from(err: HttpError) -> Self {
ErrorResponse { error: err.message } ErrorMessage { error: err.message }
}
}
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
// Try to get the request ID from the primary source
if let Some(primary) = primary {
if let Ok(uuid) = uuid::Uuid::parse_str(primary) {
return uuid.to_string();
}
} }
// Try to get the request ID header as a string slice
let request_id_opt = headers
.get(DYNAMO_REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok());
// Try to parse the request ID as a UUID, or generate a new one if missing/invalid
let uuid = match request_id_opt {
Some(request_id) => {
uuid::Uuid::parse_str(request_id).unwrap_or_else(|_| uuid::Uuid::new_v4())
}
None => uuid::Uuid::new_v4(),
};
uuid.to_string()
} }
/// OpenAI Completions Request Handler /// OpenAI Completions Request Handler
...@@ -129,11 +163,46 @@ impl From<HttpError> for ErrorResponse { ...@@ -129,11 +163,46 @@ impl From<HttpError> for ErrorResponse {
/// ///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For /// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler. /// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(skip_all)] async fn handler_completions(
async fn completions(
State(state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateCompletionRequest>, Json(request): Json<NvCreateCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let context = request.context();
// create the connection handles
let (mut connection_handle, stream_handle) = create_connection_monitor(context.clone()).await;
// possibly long running task
// if this returns a streaming response, the stream handle will be armed and captured by the response stream
let response = tokio::spawn(completions(state, request, stream_handle))
.await
.map_err(|e| {
ErrorMessage::internal_server_error(&format!(
"Failed to await chat completions task: {:?}",
e,
))
})?;
// if we got here, then we will return a response and the potentially long running task has completed successfully
// without need to be cancelled.
connection_handle.disarm();
response
}
#[tracing::instrument(skip_all)]
async fn completions(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
...@@ -144,15 +213,10 @@ async fn completions( ...@@ -144,15 +213,10 @@ async fn completions(
let streaming = request.inner.stream.unwrap_or(false); let streaming = request.inner.stream.unwrap_or(false);
// update the request to always stream // update the request to always stream
let inner = async_openai::types::CreateCompletionRequest { let request = request.map(|mut req| {
stream: Some(true), req.inner.stream = Some(true);
..request.inner req
}; });
let request = NvCreateCompletionRequest {
inner,
nvext: request.nvext,
};
// todo - make the protocols be optional for model name // todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default // todo - when optional, if none, apply a default
...@@ -162,7 +226,7 @@ async fn completions( ...@@ -162,7 +226,7 @@ async fn completions(
let engine = state let engine = state
.manager() .manager()
.get_completions_engine(model) .get_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight_guard = let mut inflight_guard =
state state
...@@ -171,27 +235,43 @@ async fn completions( ...@@ -171,27 +235,43 @@ async fn completions(
let mut response_collector = state.metrics_clone().create_response_collector(model); let mut response_collector = state.metrics_clone().create_response_collector(model);
// setup context // prepare to process any annotations
// todo - inherit request_id from distributed trace details let annotations = request.annotations();
let request = Context::with_id(request, request_id.clone());
// issue the generate call on the engine // issue the generate call on the engine
let stream = engine let stream = engine
.generate(request) .generate(request)
.await .await
.map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate completions"))?; .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
// capture the context to cancel the stream if the client disconnects // capture the context to cancel the stream if the client disconnects
let ctx = stream.context(); let ctx = stream.context();
// todo - tap the stream and propagate request level metrics let annotations = annotations.map_or(Vec::new(), |annotations| {
// note - we might do this as part of the post processing set to make it more generic annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateCompletionResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
// apply any annotations to the front of the stream
let stream = stream::iter(annotations).chain(stream);
if streaming { if streaming {
let stream = stream.map(move |response| { let stream = stream.map(move |response| {
process_event_converter(EventConverter::from(response), &mut response_collector) process_event_converter(EventConverter::from(response), &mut response_collector)
}); });
let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight_guard).await; let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream); let mut sse_stream = Sse::new(stream);
...@@ -202,7 +282,7 @@ async fn completions( ...@@ -202,7 +282,7 @@ async fn completions(
Ok(sse_stream.into_response()) Ok(sse_stream.into_response())
} else { } else {
// TODO: report ISL/OSL for non-streaming requests // TODO: report ISL/OSL for non-streaming requests
let response = NvCreateCompletionResponse::from_annotated_stream(stream.into()) let response = NvCreateCompletionResponse::from_annotated_stream(stream)
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
...@@ -210,7 +290,7 @@ async fn completions( ...@@ -210,7 +290,7 @@ async fn completions(
request_id, request_id,
e e
); );
ErrorResponse::internal_server_error("Failed to fold completions stream") ErrorMessage::internal_server_error("Failed to fold completions stream")
})?; })?;
inflight_guard.mark_ok(); inflight_guard.mark_ok();
...@@ -222,7 +302,7 @@ async fn completions( ...@@ -222,7 +302,7 @@ async fn completions(
async fn embeddings( async fn embeddings(
State(state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
Json(request): Json<NvCreateEmbeddingRequest>, Json(request): Json<NvCreateEmbeddingRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
...@@ -240,7 +320,7 @@ async fn embeddings( ...@@ -240,7 +320,7 @@ async fn embeddings(
let engine = state let engine = state
.manager() .manager()
.get_embeddings_engine(model) .get_embeddings_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
// this will increment the inflight gauge for the model // this will increment the inflight gauge for the model
let mut inflight = let mut inflight =
...@@ -256,11 +336,11 @@ async fn embeddings( ...@@ -256,11 +336,11 @@ async fn embeddings(
let stream = engine let stream = engine
.generate(request) .generate(request)
.await .await
.map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate embeddings"))?; .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate embeddings"))?;
// Embeddings are typically returned as a single response (non-streaming) // Embeddings are typically returned as a single response (non-streaming)
// so we fold the stream into a single response // so we fold the stream into a single response
let response = NvCreateEmbeddingResponse::from_annotated_stream(stream.into()) let response = NvCreateEmbeddingResponse::from_annotated_stream(stream)
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
...@@ -268,13 +348,45 @@ async fn embeddings( ...@@ -268,13 +348,45 @@ async fn embeddings(
request_id, request_id,
e e
); );
ErrorResponse::internal_server_error("Failed to fold embeddings stream") ErrorMessage::internal_server_error("Failed to fold embeddings stream")
})?; })?;
inflight.mark_ok(); inflight.mark_ok();
Ok(Json(response).into_response()) Ok(Json(response).into_response())
} }
/// Axum entry point for chat completions requests.
///
/// Builds the request [`Context`], sets up the connection monitor, and runs [`chat_completions`]
/// on a separately spawned task. The armed connection handle lives in this function, so if this
/// future is dropped before the spawned task finishes, the handle reports an unexpected close.
async fn handler_chat_completions(
    State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
    headers: HeaderMap,
    Json(request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, ErrorResponse> {
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // wrap the request in a context keyed by the resolved request id
    let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
    let request = Context::with_id(request, request_id);
    let engine_context = request.context();

    // connection handle stays here; the stream handle travels with the spawned task
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(engine_context.clone()).await;

    let task = tokio::spawn(chat_completions(state, template, request, stream_handle));
    let response = match task.await {
        Ok(response) => response,
        Err(join_error) => {
            return Err(ErrorMessage::internal_server_error(&format!(
                "Failed to await chat completions task: {:?}",
                join_error,
            )));
        }
    };

    // the spawned task ran to completion and a response will be returned, so nothing is left
    // to cancel from this scope
    connection_handle.disarm();
    response
}
/// OpenAI Chat Completions Request Handler /// OpenAI Chat Completions Request Handler
/// ///
/// This method will handle the incoming request for the /v1/chat/completions endpoint. The endpoint is a "source" /// This method will handle the incoming request for the /v1/chat/completions endpoint. The endpoint is a "source"
...@@ -283,21 +395,23 @@ async fn embeddings( ...@@ -283,21 +395,23 @@ async fn embeddings(
/// ///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For /// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler. /// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(skip_all)] #[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.id()))]
async fn chat_completions( async fn chat_completions(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>, state: Arc<service_v2::State>,
Json(mut request): Json<NvCreateChatCompletionRequest>, template: Option<RequestTemplate>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { mut request: Context<NvCreateChatCompletionRequest>,
mut stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
let request_id = request.id().to_string();
// Handle unsupported fields - if Some(resp) is returned by // Handle unsupported fields - if Some(resp) is returned by
// validate_chat_completion_unsupported_fields, // validate_chat_completion_unsupported_fields,
// then a field was used that is unsupported. We will log an error message // then a field was used that is unsupported. We will log an error message
// and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed. // and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed.
if let Some(resp) = validate_chat_completion_unsupported_fields(&request) { validate_chat_completion_unsupported_fields(&request)?;
return Ok(resp.into_response());
}
// Apply template values if present // Apply template values if present
if let Some(template) = template { if let Some(template) = template {
...@@ -311,24 +425,16 @@ async fn chat_completions( ...@@ -311,24 +425,16 @@ async fn chat_completions(
request.inner.max_completion_tokens = Some(template.max_completion_tokens); request.inner.max_completion_tokens = Some(template.max_completion_tokens);
} }
} }
tracing::trace!("Received chat completions request: {:?}", request.inner); tracing::trace!("Received chat completions request: {:?}", request.content());
// todo - extract distributed tracing id and context id from headers
let request_id = uuid::Uuid::new_v4().to_string();
// todo - decide on default // todo - decide on default
let streaming = request.inner.stream.unwrap_or(false); let streaming = request.inner.stream.unwrap_or(false);
// update the request to always stream // update the request to always stream
let inner_request = async_openai::types::CreateChatCompletionRequest { let request = request.map(|mut req| {
stream: Some(true), req.inner.stream = Some(true);
..request.inner req
}; });
let request = NvCreateChatCompletionRequest {
inner: inner_request,
nvext: request.nvext,
};
// todo - make the protocols be optional for model name // todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default // todo - when optional, if none, apply a default
...@@ -340,7 +446,7 @@ async fn chat_completions( ...@@ -340,7 +446,7 @@ async fn chat_completions(
let engine = state let engine = state
.manager() .manager()
.get_chat_completions_engine(model) .get_chat_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight_guard = let mut inflight_guard =
state state
...@@ -349,29 +455,45 @@ async fn chat_completions( ...@@ -349,29 +455,45 @@ async fn chat_completions(
let mut response_collector = state.metrics_clone().create_response_collector(model); let mut response_collector = state.metrics_clone().create_response_collector(model);
// setup context
// todo - inherit request_id from distributed trace details
let request = Context::with_id(request, request_id.clone());
tracing::trace!("Issuing generate call for chat completions"); tracing::trace!("Issuing generate call for chat completions");
let annotations = request.annotations();
// issue the generate call on the engine // issue the generate call on the engine
let stream = engine let stream = engine
.generate(request) .generate(request)
.await .await
.map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate completions"))?; .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
// capture the context to cancel the stream if the client disconnects // capture the context to cancel the stream if the client disconnects
let ctx = stream.context(); let ctx = stream.context();
// prepare any requested annotations
let annotations = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::from_annotation(ANNOTATION_REQUEST_ID, &request_id).ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
// apply any annotations to the front of the stream
let stream = stream::iter(annotations).chain(stream);
// todo - tap the stream and propagate request level metrics // todo - tap the stream and propagate request level metrics
// note - we might do this as part of the post processing set to make it more generic // note - we might do this as part of the post processing set to make it more generic
if streaming { if streaming {
stream_handle.arm();
let stream = stream.map(move |response| { let stream = stream.map(move |response| {
process_event_converter(EventConverter::from(response), &mut response_collector) process_event_converter(EventConverter::from(response), &mut response_collector)
}); });
let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight_guard).await; let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream); let mut sse_stream = Sse::new(stream);
...@@ -382,7 +504,7 @@ async fn chat_completions( ...@@ -382,7 +504,7 @@ async fn chat_completions(
Ok(sse_stream.into_response()) Ok(sse_stream.into_response())
} else { } else {
// TODO: report ISL/OSL for non-streaming requests // TODO: report ISL/OSL for non-streaming requests
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into()) let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
...@@ -390,7 +512,7 @@ async fn chat_completions( ...@@ -390,7 +512,7 @@ async fn chat_completions(
"Failed to fold chat completions stream for: {:?}", "Failed to fold chat completions stream for: {:?}",
e e
); );
ErrorResponse::internal_server_error(&format!( ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}", "Failed to fold chat completions stream: {}",
e e
)) ))
...@@ -406,44 +528,77 @@ async fn chat_completions( ...@@ -406,44 +528,77 @@ async fn chat_completions(
#[allow(deprecated)] #[allow(deprecated)]
pub fn validate_chat_completion_unsupported_fields( pub fn validate_chat_completion_unsupported_fields(
request: &NvCreateChatCompletionRequest, request: &NvCreateChatCompletionRequest,
) -> Option<impl IntoResponse> { ) -> Result<(), ErrorResponse> {
let inner = &request.inner; let inner = &request.inner;
if inner.parallel_tool_calls == Some(true) { if inner.parallel_tool_calls == Some(true) {
return Some(ErrorResponse::not_implemented_error( return Err(ErrorMessage::not_implemented_error(
"`parallel_tool_calls: true` is not supported.", "`parallel_tool_calls: true` is not supported.",
)); ));
} }
if inner.stream == Some(true) && inner.tools.is_some() { if inner.stream == Some(true) && inner.tools.is_some() {
return Some(ErrorResponse::not_implemented_error( return Err(ErrorMessage::not_implemented_error(
"`stream: true` is not supported when `tools` are provided.", "`stream: true` is not supported when `tools` are provided.",
)); ));
} }
if inner.function_call.is_some() { if inner.function_call.is_some() {
return Some(ErrorResponse::not_implemented_error( return Err(ErrorMessage::not_implemented_error(
"`function_call` is deprecated. Please migrate to use `tool_choice` instead.", "`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
)); ));
} }
if inner.functions.is_some() { if inner.functions.is_some() {
return Some(ErrorResponse::not_implemented_error( return Err(ErrorMessage::not_implemented_error(
"`functions` is deprecated. Please migrate to use `tools` instead.", "`functions` is deprecated. Please migrate to use `tools` instead.",
)); ));
} }
None Ok(())
} }
/// OpenAI Responses Request Handler /// OpenAI Responses Request Handler
/// ///
/// This method will handle the incoming request for the /v1/responses endpoint. /// This method will handle the incoming request for the /v1/responses endpoint.
#[tracing::instrument(skip_all)] async fn handler_responses(
async fn responses(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>, State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
Json(mut request): Json<NvCreateResponse>, headers: HeaderMap,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { Json(request): Json<NvCreateResponse>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let context = request.context();
// create the connection handles
let (mut connection_handle, _stream_handle) = create_connection_monitor(context.clone()).await;
let response = tokio::spawn(responses(state, template, request))
.await
.map_err(|e| {
ErrorMessage::internal_server_error(&format!(
"Failed to await chat completions task: {:?}",
e,
))
})?;
// if we got here, then we will return a response and the potentially long running task has completed successfully
// without needing to be cancelled.
connection_handle.disarm();
response
}
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.id()))]
async fn responses(
state: Arc<service_v2::State>,
template: Option<RequestTemplate>,
mut request: Context<NvCreateResponse>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
...@@ -476,21 +631,26 @@ async fn responses( ...@@ -476,21 +631,26 @@ async fn responses(
} }
tracing::trace!("Received chat completions request: {:?}", request.inner); tracing::trace!("Received chat completions request: {:?}", request.inner);
let request_id = uuid::Uuid::new_v4().to_string(); let request_id = request.id().to_string();
let (request, context) = request.into_parts();
// Convert NvCreateResponse --> NvCreateChatCompletionRequest let mut request: NvCreateChatCompletionRequest = request.try_into().map_err(|e| {
let request: NvCreateChatCompletionRequest = request.try_into().map_err(|e| {
tracing::error!( tracing::error!(
request_id, request_id,
"Failed to convert NvCreateResponse to NvCreateChatCompletionRequest: {:?}", "Failed to convert NvCreateResponse to NvCreateChatCompletionRequest: {:?}",
e e
); );
ErrorResponse::not_implemented_error(&format!( ErrorMessage::not_implemented_error(&format!(
"Only Input::Text(_) is currently supported: {}", "Only Input::Text(_) is currently supported: {}",
e e
)) ))
})?; })?;
let request = context.map(|mut _req| {
request.inner.stream = Some(false);
request
});
let model = &request.inner.model; let model = &request.inner.model;
tracing::trace!("Getting chat completions engine for model: {}", model); tracing::trace!("Getting chat completions engine for model: {}", model);
...@@ -498,7 +658,7 @@ async fn responses( ...@@ -498,7 +658,7 @@ async fn responses(
let engine = state let engine = state
.manager() .manager()
.get_chat_completions_engine(model) .get_chat_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight_guard = let mut inflight_guard =
state state
...@@ -507,18 +667,16 @@ async fn responses( ...@@ -507,18 +667,16 @@ async fn responses(
let _response_collector = state.metrics_clone().create_response_collector(model); let _response_collector = state.metrics_clone().create_response_collector(model);
let request = Context::with_id(request, request_id.clone());
tracing::trace!("Issuing generate call for chat completions"); tracing::trace!("Issuing generate call for chat completions");
// issue the generate call on the engine // issue the generate call on the engine
let stream = engine let stream = engine
.generate(request) .generate(request)
.await .await
.map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate completions"))?; .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
// TODO: handle streaming, currently just unary // TODO: handle streaming, currently just unary
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into()) let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
...@@ -526,7 +684,7 @@ async fn responses( ...@@ -526,7 +684,7 @@ async fn responses(
"Failed to fold chat completions stream for: {:?}", "Failed to fold chat completions stream for: {:?}",
e e
); );
ErrorResponse::internal_server_error(&format!( ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}", "Failed to fold chat completions stream: {}",
e e
)) ))
...@@ -539,7 +697,7 @@ async fn responses( ...@@ -539,7 +697,7 @@ async fn responses(
"Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}", "Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}",
e e
); );
ErrorResponse::internal_server_error("Failed to convert internal response") ErrorMessage::internal_server_error("Failed to convert internal response")
})?; })?;
inflight_guard.mark_ok(); inflight_guard.mark_ok();
...@@ -552,7 +710,7 @@ pub fn validate_response_input_is_text_only( ...@@ -552,7 +710,7 @@ pub fn validate_response_input_is_text_only(
) -> Option<impl IntoResponse> { ) -> Option<impl IntoResponse> {
match &request.inner.input { match &request.inner.input {
async_openai::types::responses::Input::Text(_) => None, async_openai::types::responses::Input::Text(_) => None,
_ => Some(ErrorResponse::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.")), _ => Some(ErrorMessage::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.")),
} }
} }
...@@ -564,87 +722,87 @@ pub fn validate_response_unsupported_fields( ...@@ -564,87 +722,87 @@ pub fn validate_response_unsupported_fields(
let inner = &request.inner; let inner = &request.inner;
if inner.background == Some(true) { if inner.background == Some(true) {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`background: true` is not supported.", "`background: true` is not supported.",
)); ));
} }
if inner.include.is_some() { if inner.include.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`include` is not supported.", "`include` is not supported.",
)); ));
} }
if inner.instructions.is_some() { if inner.instructions.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`instructions` is not supported.", "`instructions` is not supported.",
)); ));
} }
if inner.max_tool_calls.is_some() { if inner.max_tool_calls.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`max_tool_calls` is not supported.", "`max_tool_calls` is not supported.",
)); ));
} }
if inner.metadata.is_some() { if inner.metadata.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`metadata` is not supported.", "`metadata` is not supported.",
)); ));
} }
if inner.parallel_tool_calls == Some(true) { if inner.parallel_tool_calls == Some(true) {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`parallel_tool_calls: true` is not supported.", "`parallel_tool_calls: true` is not supported.",
)); ));
} }
if inner.previous_response_id.is_some() { if inner.previous_response_id.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`previous_response_id` is not supported.", "`previous_response_id` is not supported.",
)); ));
} }
if inner.prompt.is_some() { if inner.prompt.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`prompt` is not supported.", "`prompt` is not supported.",
)); ));
} }
if inner.reasoning.is_some() { if inner.reasoning.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`reasoning` is not supported.", "`reasoning` is not supported.",
)); ));
} }
if inner.service_tier.is_some() { if inner.service_tier.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`service_tier` is not supported.", "`service_tier` is not supported.",
)); ));
} }
if inner.store == Some(true) { if inner.store == Some(true) {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`store: true` is not supported.", "`store: true` is not supported.",
)); ));
} }
if inner.stream == Some(true) { if inner.stream == Some(true) {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`stream: true` is not supported.", "`stream: true` is not supported.",
)); ));
} }
if inner.text.is_some() { if inner.text.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`text` is not supported.", "`text` is not supported.",
)); ));
} }
if inner.tool_choice.is_some() { if inner.tool_choice.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`tool_choice` is not supported.", "`tool_choice` is not supported.",
)); ));
} }
if inner.tools.is_some() { if inner.tools.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`tools` is not supported.", "`tools` is not supported.",
)); ));
} }
if inner.truncation.is_some() { if inner.truncation.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`truncation` is not supported.", "`truncation` is not supported.",
)); ));
} }
if inner.user.is_some() { if inner.user.is_some() {
return Some(ErrorResponse::not_implemented_error( return Some(ErrorMessage::not_implemented_error(
"`user` is not supported.", "`user` is not supported.",
)); ));
} }
...@@ -654,9 +812,9 @@ pub fn validate_response_unsupported_fields( ...@@ -654,9 +812,9 @@ pub fn validate_response_unsupported_fields(
// todo - abstract this to the top level lib.rs to be reused // todo - abstract this to the top level lib.rs to be reused
// todo - move the service_observer to its own state/arc // todo - move the service_observer to its own state/arc
fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), ErrorResponse> {
// if state.service_observer.stage() != ServiceStage::Ready { // if state.service_observer.stage() != ServiceStage::Ready {
// return Err(ErrorResponse::service_unavailable()); // return Err(ErrorMessage::service_unavailable());
// } // }
Ok(()) Ok(())
} }
...@@ -676,7 +834,7 @@ fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), (StatusCode, Json< ...@@ -676,7 +834,7 @@ fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), (StatusCode, Json<
/// } /// }
async fn list_models_openai( async fn list_models_openai(
State(state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, ErrorResponse> {
check_ready(&state)?; check_ready(&state)?;
let created = SystemTime::now() let created = SystemTime::now()
...@@ -716,45 +874,6 @@ struct ModelListing { ...@@ -716,45 +874,6 @@ struct ModelListing {
owned_by: String, owned_by: String,
} }
/// This method will consume a stream of SSE events and forward them to a new stream defined by a tokio channel.
/// In this way, if the downstream is dropped, then the upstream will be unable to send any more events. This is
/// how we can monitor for disconnects and stop the generation of completions.
///
/// If a disconnect is detected, then the context will issue a `stop_generating` call to the context which will
/// propagate the cancellation signal to the backend.
async fn monitor_for_disconnects(
stream: Pin<
Box<dyn Stream<Item = Result<axum::response::sse::Event, axum::Error>> + std::marker::Send>,
>,
context: Arc<dyn AsyncEngineContext>,
mut inflight_guard: InflightGuard,
) -> ReceiverStream<Result<Event, axum::Error>> {
let (tx, rx) = tokio::sync::mpsc::channel(8);
tokio::spawn(async move {
let mut stream = stream;
while let Some(event) = stream.next().await {
let event = match event {
Ok(event) => Ok(event),
Err(err) => Ok(Event::default().event("error").comment(err.to_string())),
};
if (tx.send(event).await).is_err() {
tracing::trace!("Forwarding SSE stream was dropped; breaking loop");
context.stop_generating();
break;
}
}
// Stream completed successfully - mark as ok
if tx.send(Ok(Event::default().data("[DONE]"))).await.is_ok() {
inflight_guard.mark_ok();
}
});
ReceiverStream::new(rx)
}
struct EventConverter<T>(Annotated<T>); struct EventConverter<T>(Annotated<T>);
impl<T> From<Annotated<T>> for EventConverter<T> { impl<T> From<Annotated<T>> for EventConverter<T> {
...@@ -816,7 +935,7 @@ pub fn completions_router( ...@@ -816,7 +935,7 @@ pub fn completions_router(
let path = path.unwrap_or("/v1/completions".to_string()); let path = path.unwrap_or("/v1/completions".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(completions)) .route(&path, post(handler_completions))
.with_state(state); .with_state(state);
(vec![doc], router) (vec![doc], router)
} }
...@@ -831,7 +950,7 @@ pub fn chat_completions_router( ...@@ -831,7 +950,7 @@ pub fn chat_completions_router(
let path = path.unwrap_or("/v1/chat/completions".to_string()); let path = path.unwrap_or("/v1/chat/completions".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(chat_completions)) .route(&path, post(handler_chat_completions))
.with_state((state, template)); .with_state((state, template));
(vec![doc], router) (vec![doc], router)
} }
...@@ -876,7 +995,7 @@ pub fn responses_router( ...@@ -876,7 +995,7 @@ pub fn responses_router(
let path = path.unwrap_or("/v1/responses".to_string()); let path = path.unwrap_or("/v1/responses".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(responses)) .route(&path, post(handler_responses))
.with_state((state, template)); .with_state((state, template));
(vec![doc], router) (vec![doc], router)
} }
...@@ -942,7 +1061,7 @@ mod tests { ...@@ -942,7 +1061,7 @@ mod tests {
#[test] #[test]
fn test_http_error_response_from_anyhow() { fn test_http_error_response_from_anyhow() {
let err = http_error_from_engine(400).unwrap_err(); let err = http_error_from_engine(400).unwrap_err();
let (status, response) = ErrorResponse::from_anyhow(err, BACKUP_ERROR_MESSAGE); let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(status, StatusCode::BAD_REQUEST);
assert_eq!(response.error, "custom error message"); assert_eq!(response.error, "custom error message");
} }
...@@ -950,17 +1069,17 @@ mod tests { ...@@ -950,17 +1069,17 @@ mod tests {
#[test] #[test]
fn test_error_response_from_anyhow_out_of_range() { fn test_error_response_from_anyhow_out_of_range() {
let err = http_error_from_engine(399).unwrap_err(); let err = http_error_from_engine(399).unwrap_err();
let (status, response) = ErrorResponse::from_anyhow(err, BACKUP_ERROR_MESSAGE); let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message"); assert_eq!(response.error, "custom error message");
let err = http_error_from_engine(500).unwrap_err(); let err = http_error_from_engine(500).unwrap_err();
let (status, response) = ErrorResponse::from_anyhow(err, BACKUP_ERROR_MESSAGE); let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message"); assert_eq!(response.error, "custom error message");
let err = http_error_from_engine(501).unwrap_err(); let err = http_error_from_engine(501).unwrap_err();
let (status, response) = ErrorResponse::from_anyhow(err, BACKUP_ERROR_MESSAGE); let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message"); assert_eq!(response.error, "custom error message");
} }
...@@ -968,7 +1087,7 @@ mod tests { ...@@ -968,7 +1087,7 @@ mod tests {
#[test] #[test]
fn test_other_error_response_from_anyhow() { fn test_other_error_response_from_anyhow() {
let err = other_error_from_engine().unwrap_err(); let err = other_error_from_engine().unwrap_err();
let (status, response) = ErrorResponse::from_anyhow(err, BACKUP_ERROR_MESSAGE); let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!( assert_eq!(
response.error, response.error,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
//! both publicly via the HTTP API and internally between Dynamo components. //! both publicly via the HTTP API and internally between Dynamo components.
//! //!
use futures::StreamExt; use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod codec; pub mod codec;
...@@ -49,12 +49,12 @@ pub trait ContentProvider { ...@@ -49,12 +49,12 @@ pub trait ContentProvider {
/// Converts a stream of [codec::Message]s into a stream of [Annotated]s. /// Converts a stream of [codec::Message]s into a stream of [Annotated]s.
pub fn convert_sse_stream<R>( pub fn convert_sse_stream<R>(
stream: DataStream<Result<codec::Message, codec::SseCodecError>>, stream: impl Stream<Item = Result<codec::Message, codec::SseCodecError>>,
) -> DataStream<Annotated<R>> ) -> impl Stream<Item = Annotated<R>>
where where
R: for<'de> Deserialize<'de> + Serialize, R: for<'de> Deserialize<'de> + Serialize,
{ {
let stream = stream.map(|message| match message { stream.map(|message| match message {
Ok(message) => { Ok(message) => {
let delta = Annotated::<R>::try_from(message); let delta = Annotated::<R>::try_from(message);
match delta { match delta {
...@@ -63,6 +63,5 @@ where ...@@ -63,6 +63,5 @@ where
} }
} }
Err(e) => Annotated::from_error(e.to_string()), Err(e) => Annotated::from_error(e.to_string()),
}); })
Box::pin(stream)
} }
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use futures::StreamExt; use futures::{Stream, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
...@@ -22,7 +22,6 @@ use crate::protocols::{ ...@@ -22,7 +22,6 @@ use crate::protocols::{
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
use dynamo_runtime::engine::DataStream; use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
...@@ -95,7 +94,7 @@ impl DeltaAggregator { ...@@ -95,7 +94,7 @@ impl DeltaAggregator {
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful. /// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
/// * `Err(String)` if an error occurs during processing. /// * `Err(String)` if an error occurs during processing.
pub async fn apply( pub async fn apply(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
...@@ -260,7 +259,7 @@ impl NvCreateChatCompletionResponse { ...@@ -260,7 +259,7 @@ impl NvCreateChatCompletionResponse {
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds. /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs. /// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream( pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream).await
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::{Stream, StreamExt};
use super::NvCreateCompletionResponse; use super::NvCreateCompletionResponse;
use crate::protocols::{ use crate::protocols::{
...@@ -64,7 +64,7 @@ impl DeltaAggregator { ...@@ -64,7 +64,7 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`]. /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub async fn apply( pub async fn apply(
stream: DataStream<Annotated<NvCreateCompletionResponse>>, stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
) -> Result<NvCreateCompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
...@@ -183,7 +183,7 @@ impl NvCreateCompletionResponse { ...@@ -183,7 +183,7 @@ impl NvCreateCompletionResponse {
} }
pub async fn from_annotated_stream( pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateCompletionResponse>>, stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
) -> Result<NvCreateCompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream).await
} }
......
...@@ -20,7 +20,7 @@ use crate::protocols::{ ...@@ -20,7 +20,7 @@ use crate::protocols::{
}; };
use dynamo_runtime::engine::DataStream; use dynamo_runtime::engine::DataStream;
use futures::StreamExt; use futures::{Stream, StreamExt};
/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
/// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler /// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
...@@ -58,7 +58,7 @@ impl DeltaAggregator { ...@@ -58,7 +58,7 @@ impl DeltaAggregator {
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful. /// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful.
/// * `Err(String)` if an error occurs during processing. /// * `Err(String)` if an error occurs during processing.
pub async fn apply( pub async fn apply(
stream: DataStream<Annotated<NvCreateEmbeddingResponse>>, stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
) -> Result<NvCreateEmbeddingResponse, String> { ) -> Result<NvCreateEmbeddingResponse, String> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
...@@ -133,7 +133,7 @@ impl NvCreateEmbeddingResponse { ...@@ -133,7 +133,7 @@ impl NvCreateEmbeddingResponse {
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds. /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs. /// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream( pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateEmbeddingResponse>>, stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
) -> Result<NvCreateEmbeddingResponse, String> { ) -> Result<NvCreateEmbeddingResponse, String> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream).await
} }
......
...@@ -28,6 +28,8 @@ use dynamo_llm::http::{ ...@@ -28,6 +28,8 @@ use dynamo_llm::http::{
}, },
}; };
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
codec::SseLineCodec,
convert_sse_stream,
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -45,11 +47,31 @@ use futures::StreamExt; ...@@ -45,11 +47,31 @@ use futures::StreamExt;
use prometheus::{proto::MetricType, Registry}; use prometheus::{proto::MetricType, Registry};
use reqwest::StatusCode; use reqwest::StatusCode;
use rstest::*; use rstest::*;
use std::sync::Arc; use std::{io::Cursor, sync::Arc};
use tokio::time::timeout;
use tokio_util::codec::FramedRead;
struct CounterEngine {} struct CounterEngine {}
#[allow(deprecated)] // Add a new long-running test engine
/// Test engine that carries a configurable delay and a shared cancellation flag.
struct LongRunningEngine {
    /// Delay, in milliseconds, configured at construction time.
    delay_ms: u64,
    /// Shared flag reported by [`LongRunningEngine::was_cancelled`]; starts out `false`.
    cancelled: Arc<std::sync::atomic::AtomicBool>,
}

impl LongRunningEngine {
    /// Builds an engine with the given delay and a cleared cancellation flag.
    fn new(delay_ms: u64) -> Self {
        use std::sync::atomic::AtomicBool;

        let cancelled = Arc::new(AtomicBool::new(false));
        Self {
            delay_ms,
            cancelled,
        }
    }

    /// Returns the current value of the cancellation flag.
    fn was_cancelled(&self) -> bool {
        use std::sync::atomic::Ordering;

        let flag = self.cancelled.as_ref();
        flag.load(Ordering::Acquire)
    }
}
#[async_trait] #[async_trait]
impl impl
AsyncEngine< AsyncEngine<
...@@ -66,6 +88,7 @@ impl ...@@ -66,6 +88,7 @@ impl
let ctx = context.context(); let ctx = context.context();
// ALLOW: max_tokens is deprecated in favor of completion_usage_tokens // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
#[allow(deprecated)]
let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64; let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
// let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone()); // let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
...@@ -88,6 +111,54 @@ impl ...@@ -88,6 +111,54 @@ impl
} }
} }
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for LongRunningEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (_request, context) = request.transfer(());
let ctx = context.context();
tracing::info!(
"LongRunningEngine: Starting generation with {}ms delay",
self.delay_ms
);
let cancelled_flag = self.cancelled.clone();
let delay_ms = self.delay_ms;
let ctx_clone = ctx.clone();
let stream = async_stream::stream! {
// the stream can be dropped or it can be cancelled
// either way we consider this a cancellation
cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_millis(delay_ms)) => {
// the stream went to completion
cancelled_flag.store(false, std::sync::atomic::Ordering::SeqCst);
}
_ = ctx_clone.stopped() => {
cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
yield Annotated::<NvCreateChatCompletionStreamResponse>::from_annotation("event.dynamo.test.sentinel", &"DONE".to_string()).expect("Failed to create annotated response");
};
Ok(ResponseStream::new(Box::pin(stream), ctx))
}
}
struct AlwaysFailEngine {} struct AlwaysFailEngine {}
#[async_trait] #[async_trait]
...@@ -880,3 +951,311 @@ async fn test_generic_byot_client( ...@@ -880,3 +951,311 @@ async fn test_generic_byot_client(
cancel_token.cancel(); cancel_token.cancel();
task.await.unwrap().unwrap(); task.await.unwrap().unwrap();
} }
#[rstest]
#[tokio::test]
async fn test_client_disconnect_cancellation_unary() {
let service = HttpService::builder().port(8993).build().unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8993).await;
// Create a long-running engine (10 seconds)
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-model", long_running_engine.clone())
.unwrap();
let client = reqwest::Client::new();
let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"This will take a long time".to_string(),
),
name: None,
},
);
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("slow-model")
.messages(vec![message])
.stream(false) // Test unary response
.build()
.expect("Failed to build request");
// Start the request and cancel it after 1 second
let start_time = std::time::Instant::now();
let request_future = async {
client
.post("http://localhost:8993/v1/chat/completions")
.json(&request)
.send()
.await
};
// Use timeout to simulate client disconnect after 1 second
let result = timeout(std::time::Duration::from_millis(1000), request_future).await;
let elapsed = start_time.elapsed();
// The request should timeout (simulating client disconnect)
assert!(result.is_err(), "Request should have timed out");
// Give the service a moment to detect the disconnect and propagate cancellation
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
// Verify the engine was cancelled
assert!(
long_running_engine.was_cancelled(),
"Engine should have been cancelled due to client disconnect"
);
// Verify cancellation happened quickly (within 2 seconds, not the full 10 seconds)
assert!(
elapsed < std::time::Duration::from_secs(2),
"Cancellation should have propagated quickly, took {:?}",
elapsed
);
tracing::info!(
"✅ Client disconnect test passed! Request cancelled in {:?}, engine detected cancellation",
elapsed
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}
/// Verifies that dropping a streaming chat completion response on the client
/// side cancels the engine well before its 10 second generation delay would
/// have elapsed.
#[rstest]
#[tokio::test]
async fn test_client_disconnect_cancellation_streaming() {
    use std::time::{Duration, Instant};

    dynamo_runtime::logging::init();

    let service = HttpService::builder().port(8994).build().unwrap();
    let state = service.state_clone();
    let manager = state.manager();

    let token = CancellationToken::new();
    let cancel_token = token.clone();

    // Start the service
    let task = tokio::spawn(async move { service.run(token).await });

    // Wait for service to be ready
    wait_for_service_ready(8994).await;

    // Create a long-running engine (10 seconds)
    let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
    manager
        .add_chat_completions_model("slow-stream-model", long_running_engine.clone())
        .unwrap();

    let client = reqwest::Client::new();

    let message = async_openai::types::ChatCompletionRequestMessage::User(
        async_openai::types::ChatCompletionRequestUserMessage {
            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
                "This will stream for a long time".to_string(),
            ),
            name: None,
        },
    );

    let request = async_openai::types::CreateChatCompletionRequestArgs::default()
        .model("slow-stream-model")
        .messages(vec![message])
        .stream(true) // Test streaming response
        .build()
        .expect("Failed to build request");

    // Start the request; the client reads from the stream for a while and the
    // surrounding timeout drops everything after 1.5 seconds at the latest.
    let start_time = Instant::now();
    let request_future = async {
        let response = client
            .post("http://localhost:8994/v1/chat/completions")
            .json(&request)
            .send()
            .await
            .expect("Streaming request should have been accepted");

        // Start reading the stream, then drop it to simulate client disconnect
        let mut stream = response.bytes_stream();
        tokio::time::sleep(Duration::from_millis(500)).await;

        // Try to read one chunk; the stream is dropped (simulating a client
        // disconnect) when this future completes or is cancelled by the timeout.
        let _ = StreamExt::next(&mut stream).await;
    };

    // Either outcome (chunk received or timeout fired) ends with the response
    // stream being dropped, so the result itself is irrelevant.
    let _result = timeout(Duration::from_millis(1500), request_future).await;

    // Poll for the cancellation instead of sleeping a fixed amount, so that
    // `elapsed` reflects when the engine actually observed the disconnect
    // rather than only the client-side timeout.
    let detection_deadline = start_time + Duration::from_secs(3);
    while !long_running_engine.was_cancelled() && Instant::now() < detection_deadline {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    let elapsed = start_time.elapsed();

    // Verify the engine was cancelled
    assert!(
        long_running_engine.was_cancelled(),
        "Engine should have been cancelled due to streaming client disconnect"
    );

    // Verify cancellation happened reasonably quickly
    assert!(
        elapsed < Duration::from_secs(3),
        "Stream cancellation should have propagated reasonably quickly, took {:?}",
        elapsed
    );

    tracing::info!(
        "✅ Streaming client disconnect test passed! Stream cancelled in {:?}, engine detected cancellation",
        elapsed
    );

    cancel_token.cancel();
    task.await.unwrap().unwrap();
}
/// Sends a streaming chat completion with an `x-dynamo-request-id` header and
/// the `request_id` annotation enabled, then checks that the SSE response
/// carries a `request_id` event echoing the same id.
#[rstest]
#[tokio::test]
async fn test_request_id_annotation() {
    // TODO(ryan): make better fixtures, this is too much to test something so simple
    dynamo_runtime::logging::init();

    let service = HttpService::builder().port(8995).build().unwrap();
    let state = service.state_clone();
    let manager = state.manager();

    let token = CancellationToken::new();
    let cancel_token = token.clone();

    // Run the service in the background and wait until it accepts connections.
    let task = tokio::spawn(async move { service.run(token).await });
    wait_for_service_ready(8995).await;

    // Register a counter engine under the model name used by the request.
    manager
        .add_chat_completions_model("test-model", Arc::new(CounterEngine {}))
        .unwrap();

    // The id we expect the service to echo back.
    let request_uuid = uuid::Uuid::new_v4();
    let sent_request_id = request_uuid.to_string();

    // Raw JSON body so the `nvext` annotations can be set directly.
    let request_json = serde_json::json!({
        "model": "test-model",
        "messages": [
            {
                "role": "user",
                "content": "Test request with annotation"
            }
        ],
        "stream": true,
        "max_tokens": 50,
        "nvext": {
            "annotations": ["request_id"]
        }
    });

    // Issue the streaming request with the custom request-id header.
    let response = reqwest::Client::new()
        .post("http://localhost:8995/v1/chat/completions")
        .header("x-dynamo-request-id", sent_request_id.clone())
        .json(&request_json)
        .send()
        .await
        .expect("Request should succeed");

    assert!(
        response.status().is_success(),
        "Response should be successful"
    );

    // Buffer the whole body, then replay it through the SSE codec.
    let body_bytes = response
        .bytes()
        .await
        .expect("Failed to read response body");
    let sse_source = Cursor::new(String::from_utf8_lossy(&body_bytes).into_owned());
    let framed = FramedRead::new(sse_source, SseLineCodec::new());
    let annotated_stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(framed);
    let mut annotated_stream = std::pin::pin!(annotated_stream);

    // Scan for the first `request_id` event and pull the id out of its comment.
    let mut found_request_id_annotation = false;
    let mut received_request_id: Option<String> = None;

    while let Some(annotated_response) = annotated_stream.next().await {
        match &annotated_response.event {
            Some(event) if event == "request_id" => {}
            _ => continue,
        }

        found_request_id_annotation = true;

        // The comment holds a JSON-encoded string; if decoding fails, fall
        // back to stripping the surrounding quotes by hand.
        received_request_id = annotated_response
            .comment
            .as_ref()
            .and_then(|comments| comments.first())
            .map(|comment| {
                serde_json::from_str::<String>(comment)
                    .unwrap_or_else(|_| comment.trim_matches('"').to_string())
            });

        break;
    }

    // The annotation must be present...
    assert!(
        found_request_id_annotation,
        "Should have received request_id annotation in the stream"
    );

    // ...and it must carry an id...
    assert!(
        received_request_id.is_some(),
        "Should have received the request ID in the annotation"
    );

    // ...that matches the one we sent.
    let received_uuid_str = received_request_id.unwrap();
    assert_eq!(
        received_uuid_str,
        sent_request_id,
        "Received request ID should match the one we sent: expected {}, got {}",
        request_uuid,
        received_uuid_str
    );

    tracing::info!(
        "✅ Request ID annotation test passed! Sent UUID: {}, Received: {}",
        request_uuid,
        received_uuid_str
    );

    cancel_token.cancel();
    task.await.unwrap().unwrap();
}
...@@ -75,10 +75,16 @@ impl<T: Send + Sync + 'static> Context<T> { ...@@ -75,10 +75,16 @@ impl<T: Send + Sync + 'static> Context<T> {
} }
} }
/// Get the id of the context
pub fn id(&self) -> &str { pub fn id(&self) -> &str {
self.controller.id() self.controller.id()
} }
/// Get the content of the context
///
/// Returns a shared reference to the value stored in this context's `current`
/// field. The value is only borrowed; nothing is cloned or moved out.
pub fn content(&self) -> &T {
&self.current
}
pub fn controller(&self) -> &Controller { pub fn controller(&self) -> &Controller {
&self.controller &self.controller
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment