Unverified Commit b4ddca99 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Failure Detection while Responses are returning (#1671)

parent bd91dca6
......@@ -214,7 +214,7 @@ struct Endpoint {
#[pyclass]
#[derive(Clone)]
struct Client {
router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>,
router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
}
#[pyclass(eq, eq_int)]
......@@ -485,13 +485,12 @@ impl Endpoint {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = inner.client().await.map_err(to_pyerr)?;
let push_router =
rs::pipeline::PushRouter::<serde_json::Value, serde_json::Value>::from_client(
client,
Default::default(),
)
.await
.map_err(to_pyerr)?;
let push_router = rs::pipeline::PushRouter::<
serde_json::Value,
RsAnnotated<serde_json::Value>,
>::from_client(client, Default::default())
.await
.map_err(to_pyerr)?;
Ok(Client {
router: push_router,
})
......@@ -757,23 +756,13 @@ impl Client {
}
async fn process_stream(
stream: EngineStream<serde_json::Value>,
stream: EngineStream<RsAnnotated<serde_json::Value>>,
tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>,
) {
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert the response to a PyObject using Python's GIL
// TODO: Remove the clone, but still log the full JSON string on error. But how?
let annotated: RsAnnotated<serde_json::Value> = match serde_json::from_value(
response.clone(),
) {
Ok(a) => a,
Err(err) => {
tracing::error!(%err, %response, "process_stream: Failed de-serializing JSON into RsAnnotated");
break;
}
};
let annotated: RsAnnotated<serde_json::Value> = response;
let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) {
Ok(pyobj) => Ok(pyobj.into()),
......
......@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>;
......@@ -134,6 +135,20 @@ impl LLMEngineOutput {
}
}
impl MaybeError for LLMEngineOutput {
fn from_err(err: Box<dyn std::error::Error>) -> Self {
LLMEngineOutput::error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
Some(anyhow::Error::msg(err_msg.clone()).into())
} else {
None
}
}
}
/// Raw output from embedding engines containing embedding vectors
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct EmbeddingsEngineOutput {
......@@ -144,3 +159,26 @@ pub struct EmbeddingsEngineOutput {
pub prompt_tokens: u32,
pub total_tokens: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the `MaybeError` impl for outputs built via `stop`, `error`
    /// and `from_err`.
    #[test]
    fn test_maybe_error() {
        // An output that stopped normally carries no error.
        let stopped = LLMEngineOutput::stop();
        assert!(stopped.err().is_none());
        assert!(stopped.is_ok());
        assert!(!stopped.is_err());

        // An explicit error output exposes its message through `err()`.
        let failed = LLMEngineOutput::error("Test error".to_string());
        assert_eq!(failed.err().unwrap().to_string(), "Test error");
        assert!(!failed.is_ok());
        assert!(failed.is_err());

        // An output converted from an error round-trips the message.
        let source = anyhow::Error::msg("Test error 2");
        let converted = LLMEngineOutput::from_err(source.into());
        assert_eq!(converted.err().unwrap().to_string(), "Test error 2");
        assert!(!converted.is_ok());
        assert!(converted.is_err());
    }
}
......@@ -323,3 +323,54 @@ impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
/// Implemented by types that can handle a raw byte payload.
///
/// Implementors must be `Send + Sync`.
pub trait PushWorkHandler: Send + Sync {
    /// Handles one `payload` of raw bytes, returning a `PipelineError` on failure.
    // NOTE(review): how the payload is decoded and where responses go is decided
    // by the implementor and is not visible from this declaration.
    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
}
/*
/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
/// in network communication between ingress and egress components.
///
/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
/// gracefully or was cut off prematurely (e.g., due to network issues).
///
/// **Design Rationale**:
/// - Cannot use `Annotated` directly because the generic type `U` varies:
/// - Sometimes `U = Annotated<...>`
/// - Sometimes `U = LLMEngineOutput<...>`
/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
/// - A simple wrapper is cleaner and more straightforward
///
/// **Stream Flow**:
/// ```
/// At AsyncEngine:
/// response 1 -> response 2 -> response 3 -> <end>
///
/// Between ingress/egress:
/// response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
///
/// At client:
/// response 1 -> response 2 -> response 3 -> <end>
/// ```
///
/// **Error Handling**:
/// If the stream is cut off before proper termination, the egress is responsible for
/// injecting an error response to communicate the incomplete stream to the client:
/// ```
/// At AsyncEngine:
/// response 1 -> ... <without end flag>
///
/// At egress:
/// response 1 <end=false> -> <stream ended without end flag -> convert to error>
///
/// At client:
/// response 1 -> error response
/// ```
///
/// The detection must be done at egress level because premature stream termination
/// can be due to network issues that only the egress component can detect.
*/
/// Envelope serialized around each response sent between ingress and egress so
/// the receiving side can tell a gracefully completed stream from one that was
/// cut off before its final message.
///
/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
#[derive(Serialize, Deserialize, Debug)]
pub struct NetworkStreamWrapper<U> {
    /// The wrapped response; left out of the serialized form when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<U>,
    /// End-of-stream flag: `false` on messages that carry a response, `true` on
    /// the terminating message that marks the stream as properly completed.
    pub complete_final: bool,
}
......@@ -17,7 +17,8 @@ use async_nats::client::Client;
use tracing as log;
use super::*;
use crate::Result;
use crate::{protocols::maybe_error::MaybeError, Result};
use tokio_stream::{wrappers::ReceiverStream, StreamExt, StreamNotifyClose};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
......@@ -80,7 +81,7 @@ impl AddressedPushRouter {
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
let request_id = request.context().id().to_string();
......@@ -160,16 +161,49 @@ where
.map_err(|_| PipelineError::DetatchedStreamReceiver)?
.map_err(PipelineError::ConnectionFailed)?;
let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx);
let stream = stream.filter_map(|msg| async move {
match serde_json::from_slice::<U>(&msg) {
Ok(r) => Some(r),
Err(err) => {
let json_str = String::from_utf8_lossy(&msg);
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
None
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
let mut is_complete_final = false;
let stream = tokio_stream::StreamNotifyClose::new(
tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
)
.filter_map(move |res| {
if let Some(res_bytes) = res {
if is_complete_final {
return Some(U::from_err(
Error::msg(
"Response received after generation ended - this should never happen",
)
.into(),
));
}
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
Ok(item) => {
is_complete_final = item.complete_final;
if let Some(data) = item.data {
Some(data)
} else if is_complete_final {
None
} else {
Some(U::from_err(
Error::msg("Empty response received - this should never happen")
.into(),
))
}
}
Err(err) => {
// legacy log print
let json_str = String::from_utf8_lossy(&res_bytes);
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(Error::new(err).into()))
}
}
} else if is_complete_final {
None
} else {
Some(U::from_err(
Error::msg("Stream ended before generation completed").into(),
))
}
});
......
......@@ -13,6 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{AsyncEngineContextProvider, ResponseStream};
use crate::{
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
pipeline::{
error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
},
protocols::maybe_error::MaybeError,
traits::DistributedRuntimeProvider,
};
use async_nats::client::{
RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
};
......@@ -27,15 +37,7 @@ use std::{
Arc,
},
};
use crate::{
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
pipeline::{
error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
},
traits::DistributedRuntimeProvider,
};
use tokio_stream::StreamExt;
#[derive(Clone)]
pub struct PushRouter<T, U>
......@@ -94,7 +96,7 @@ async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPu
impl<T, U> PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
......@@ -109,51 +111,44 @@ where
/// Issue a request to the next available instance in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let slf = self;
let routing_algorithm = move || async move {
let counter = slf.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let instance_id = {
let instances = slf.client.instances_avail().await;
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}",
slf.client.endpoint.etcd_root()
));
}
let offset = counter % count as u64;
instances[offset as usize].id()
};
tracing::trace!("round robin router selected {instance_id}");
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
Ok(instance_id)
let instance_id = {
let instances = self.client.instances_avail().await;
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
let offset = counter % count as u64;
instances[offset as usize].id()
};
self.generate_with_fault_tolerance(routing_algorithm, request)
tracing::trace!("round robin router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await
}
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let slf = self;
let routing_algorithm = move || async move {
let instance_id = {
let instances = slf.client.instances_avail().await;
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}",
slf.client.endpoint.etcd_root()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
instances[offset as usize].id()
};
tracing::trace!("random router selected {instance_id}");
Ok(instance_id)
let instance_id = {
let instances = self.client.instances_avail().await;
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
instances[offset as usize].id()
};
self.generate_with_fault_tolerance(routing_algorithm, request)
tracing::trace!("random router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await
}
......@@ -163,22 +158,19 @@ where
request: SingleIn<T>,
instance_id: i64,
) -> anyhow::Result<ManyOut<U>> {
let slf = self;
let routing_algorithm = move || async move {
let found = {
let instances = slf.client.instances_avail().await;
instances.iter().any(|ep| ep.id() == instance_id)
};
if !found {
return Err(anyhow::anyhow!(
"instance_id={instance_id} not found for endpoint {:?}",
slf.client.endpoint.etcd_root()
));
}
Ok(instance_id)
let found = {
let instances = self.client.instances_avail().await;
instances.iter().any(|ep| ep.id() == instance_id)
};
self.generate_with_fault_tolerance(routing_algorithm, request)
if !found {
return Err(anyhow::anyhow!(
"instance_id={instance_id} not found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
self.generate_with_fault_detection(instance_id, request)
.await
}
......@@ -190,29 +182,45 @@ where
self.addressed.generate(request).await
}
async fn generate_with_fault_tolerance<F, R>(
async fn generate_with_fault_detection(
&self,
routing_algorithm: F,
instance_id: i64,
request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>>
where
F: FnOnce() -> R,
R: Future<Output = anyhow::Result<i64>>,
{
let instance_id = routing_algorithm().await?;
) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
let stream = self.addressed.generate(request).await;
if let Some(err) = stream.as_ref().err() {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) {
self.client.report_instance_down(instance_id).await;
let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
match stream {
Ok(stream) => {
let engine_ctx = stream.context();
let client = self.client.clone();
let stream = stream.then(move |res| {
let mut report_instance_down: Option<(Client, i64)> = None;
if let Some(err) = res.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
report_instance_down = Some((client.clone(), instance_id));
}
}
async move {
if let Some((client, instance_id)) = report_instance_down {
client.report_instance_down(instance_id).await;
}
res
}
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
Err(err) => {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) {
self.client.report_instance_down(instance_id).await;
}
}
Err(err)
}
}
stream
}
}
......@@ -220,7 +228,7 @@ where
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match self.client.instance_source.as_ref() {
......
......@@ -97,16 +97,37 @@ where
let context = stream.context();
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
let mut send_complete_final = true;
while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp);
let resp_bytes = serde_json::to_vec(&resp)
let resp_wrapper = NetworkStreamWrapper {
data: Some(resp),
complete_final: false,
};
let resp_bytes = serde_json::to_vec(&resp_wrapper)
.expect("fatal error: invalid response object - this should never happen");
if (publisher.send(resp_bytes.into()).await).is_err() {
tracing::error!("Failed to publish response for stream {}", context.id());
context.stop_generating();
send_complete_final = false;
break;
}
}
if send_complete_final {
let resp_wrapper = NetworkStreamWrapper::<U> {
data: None,
complete_final: true,
};
let resp_bytes = serde_json::to_vec(&resp_wrapper)
.expect("fatal error: invalid response object - this should never happen");
if (publisher.send(resp_bytes.into()).await).is_err() {
tracing::error!(
"Failed to publish complete final for stream {}",
context.id()
);
}
}
Ok(())
}
......
......@@ -19,6 +19,7 @@ use std::str::FromStr;
use crate::pipeline::PipelineError;
pub mod annotated;
pub mod maybe_error;
pub type LeaseId = i64;
......
......@@ -15,6 +15,7 @@
use super::*;
use crate::{error, Result};
use maybe_error::MaybeError;
pub trait AnnotationsProvider {
fn annotations(&self) -> Option<Vec<String>>;
......@@ -28,7 +29,7 @@ pub trait AnnotationsProvider {
/// Our services have the option of returning an "annotated" stream, which allows us
/// to include additional information with each delta. This is useful for debugging,
/// performance benchmarking, and improved observability.
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Annotated<R> {
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<R>,
......@@ -146,6 +147,28 @@ impl<R> Annotated<R> {
}
}
impl<R> MaybeError for Annotated<R>
where
R: for<'de> Deserialize<'de> + Serialize,
{
fn from_err(err: Box<dyn std::error::Error>) -> Self {
Annotated::from_error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
if self.is_error() {
if let Some(comment) = &self.comment {
if !comment.is_empty() {
return Some(anyhow::Error::msg(comment.join("; ")).into());
}
}
Some(anyhow::Error::msg("unknown error").into())
} else {
None
}
}
}
// impl<R> Annotated<R>
// where
// R: for<'de> Deserialize<'de> + Serialize,
......@@ -166,3 +189,27 @@ impl<R> Annotated<R> {
// Box::pin(stream)
// }
// }
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the `MaybeError` impl for items built via `from_data`,
    /// `from_error` and `from_err`.
    #[test]
    fn test_maybe_error() {
        // A data-carrying item is not an error.
        let with_data = Annotated::from_data("Test data".to_string());
        assert!(with_data.err().is_none());
        assert!(with_data.is_ok());
        assert!(!with_data.is_err());

        // An item built from an error message exposes that message.
        let from_message = Annotated::<String>::from_error("Test error 2".to_string());
        assert_eq!(from_message.err().unwrap().to_string(), "Test error 2");
        assert!(!from_message.is_ok());
        assert!(from_message.is_err());

        // An item converted from an error value round-trips the message.
        let source = anyhow::Error::msg("Test error 3".to_string());
        let from_value = Annotated::<String>::from_err(source.into());
        assert_eq!(from_value.err().unwrap().to_string(), "Test error 3");
        assert!(!from_value.is_ok());
        assert!(from_value.is_err());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::error::Error;
/// A value that may represent an error in-band, instead of being wrapped in a
/// `Result`.
///
/// Implementors provide [`MaybeError::from_err`] and [`MaybeError::err`]; the
/// success/failure predicates are derived from `err` by default.
pub trait MaybeError {
    /// Construct an instance that represents the given error.
    fn from_err(err: Box<dyn Error>) -> Self;
    /// Return the error this instance represents, or `None` if it does not
    /// represent an error.
    fn err(&self) -> Option<Box<dyn Error>>;
    /// Check if the current instance represents a success.
    ///
    /// Defaults to the negation of [`MaybeError::is_err`].
    fn is_ok(&self) -> bool {
        !self.is_err()
    }
    /// Check if the current instance represents an error.
    ///
    /// Defaults to whether [`MaybeError::err`] returns `Some`.
    fn is_err(&self) -> bool {
        self.err().is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal implementor that always reports an error; used to exercise the
    /// trait's default `is_ok` / `is_err` methods.
    struct TestError {
        message: String,
    }

    impl MaybeError for TestError {
        fn from_err(err: Box<dyn Error>) -> Self {
            let message = err.to_string();
            Self { message }
        }

        fn err(&self) -> Option<Box<dyn Error>> {
            let error = anyhow::Error::msg(self.message.clone());
            Some(error.into())
        }
    }

    #[test]
    fn test_maybe_error_default_implementations() {
        let source = anyhow::Error::msg("Test error".to_string());
        let subject = TestError::from_err(source.into());

        assert_eq!(subject.err().unwrap().to_string(), "Test error");
        assert!(!subject.is_ok());
        assert!(subject.is_err());
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment