Unverified Commit b4ddca99 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Failure Detection while Responses are returning (#1671)

parent bd91dca6
...@@ -214,7 +214,7 @@ struct Endpoint { ...@@ -214,7 +214,7 @@ struct Endpoint {
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
struct Client { struct Client {
router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>, router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
} }
#[pyclass(eq, eq_int)] #[pyclass(eq, eq_int)]
...@@ -485,13 +485,12 @@ impl Endpoint { ...@@ -485,13 +485,12 @@ impl Endpoint {
let inner = self.inner.clone(); let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = inner.client().await.map_err(to_pyerr)?; let client = inner.client().await.map_err(to_pyerr)?;
let push_router = let push_router = rs::pipeline::PushRouter::<
rs::pipeline::PushRouter::<serde_json::Value, serde_json::Value>::from_client( serde_json::Value,
client, RsAnnotated<serde_json::Value>,
Default::default(), >::from_client(client, Default::default())
) .await
.await .map_err(to_pyerr)?;
.map_err(to_pyerr)?;
Ok(Client { Ok(Client {
router: push_router, router: push_router,
}) })
...@@ -757,23 +756,13 @@ impl Client { ...@@ -757,23 +756,13 @@ impl Client {
} }
async fn process_stream( async fn process_stream(
stream: EngineStream<serde_json::Value>, stream: EngineStream<RsAnnotated<serde_json::Value>>,
tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>, tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>,
) { ) {
let mut stream = stream; let mut stream = stream;
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
// Convert the response to a PyObject using Python's GIL // Convert the response to a PyObject using Python's GIL
// TODO: Remove the clone, but still log the full JSON string on error. But how? let annotated: RsAnnotated<serde_json::Value> = response;
let annotated: RsAnnotated<serde_json::Value> = match serde_json::from_value(
response.clone(),
) {
Ok(a) => a,
Err(err) => {
tracing::error!(%err, %response, "process_stream: Failed de-serializing JSON into RsAnnotated");
break;
}
};
let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| { let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) { let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) {
Ok(pyobj) => Ok(pyobj.into()), Ok(pyobj) => Ok(pyobj.into()),
......
...@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize}; ...@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
pub use super::preprocessor::PreprocessedRequest; pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason; pub use super::FinishReason;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>; pub type LogProbs = Vec<f64>;
...@@ -134,6 +135,20 @@ impl LLMEngineOutput { ...@@ -134,6 +135,20 @@ impl LLMEngineOutput {
} }
} }
impl MaybeError for LLMEngineOutput {
fn from_err(err: Box<dyn std::error::Error>) -> Self {
LLMEngineOutput::error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
Some(anyhow::Error::msg(err_msg.clone()).into())
} else {
None
}
}
}
/// Raw output from embedding engines containing embedding vectors /// Raw output from embedding engines containing embedding vectors
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct EmbeddingsEngineOutput { pub struct EmbeddingsEngineOutput {
...@@ -144,3 +159,26 @@ pub struct EmbeddingsEngineOutput { ...@@ -144,3 +159,26 @@ pub struct EmbeddingsEngineOutput {
pub prompt_tokens: u32, pub prompt_tokens: u32,
pub total_tokens: u32, pub total_tokens: u32,
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_maybe_error() {
let output = LLMEngineOutput::stop();
assert!(output.err().is_none());
assert!(output.is_ok());
assert!(!output.is_err());
let output = LLMEngineOutput::error("Test error".to_string());
assert_eq!(format!("{}", output.err().unwrap()), "Test error");
assert!(!output.is_ok());
assert!(output.is_err());
let output = LLMEngineOutput::from_err(anyhow::Error::msg("Test error 2").into());
assert_eq!(format!("{}", output.err().unwrap()), "Test error 2");
assert!(!output.is_ok());
assert!(output.is_err());
}
}
...@@ -323,3 +323,54 @@ impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> { ...@@ -323,3 +323,54 @@ impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
pub trait PushWorkHandler: Send + Sync { pub trait PushWorkHandler: Send + Sync {
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>; async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;
} }
/*
/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
/// in network communication between ingress and egress components.
///
/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
/// gracefully or was cut off prematurely (e.g., due to network issues).
///
/// **Design Rationale**:
/// - Cannot use `Annotated` directly because the generic type `U` varies:
/// - Sometimes `U = Annotated<...>`
/// - Sometimes `U = LLMEngineOutput<...>`
/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
/// - A simple wrapper is cleaner and more straightforward
///
/// **Stream Flow**:
/// ```
/// At AsyncEngine:
/// response 1 -> response 2 -> response 3 -> <end>
///
/// Between ingress/egress:
/// response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
///
/// At client:
/// response 1 -> response 2 -> response 3 -> <end>
/// ```
///
/// **Error Handling**:
/// If the stream is cut off before proper termination, the egress is responsible for
/// injecting an error response to communicate the incomplete stream to the client:
/// ```
/// At AsyncEngine:
/// response 1 -> ... <without end flag>
///
/// At egress:
/// response 1 <end=false> -> <stream ended without end flag -> convert to error>
///
/// At client:
/// response 1 -> error response
/// ```
///
/// The detection must be done at egress level because premature stream termination
/// can be due to network issues that only the egress component can detect.
*/
/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
#[derive(Serialize, Deserialize, Debug)]
pub struct NetworkStreamWrapper<U> {
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<U>,
pub complete_final: bool,
}
...@@ -17,7 +17,8 @@ use async_nats::client::Client; ...@@ -17,7 +17,8 @@ use async_nats::client::Client;
use tracing as log; use tracing as log;
use super::*; use super::*;
use crate::Result; use crate::{protocols::maybe_error::MaybeError, Result};
use tokio_stream::{wrappers::ReceiverStream, StreamExt, StreamNotifyClose};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -80,7 +81,7 @@ impl AddressedPushRouter { ...@@ -80,7 +81,7 @@ impl AddressedPushRouter {
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de> + MaybeError,
{ {
async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
let request_id = request.context().id().to_string(); let request_id = request.context().id().to_string();
...@@ -160,16 +161,49 @@ where ...@@ -160,16 +161,49 @@ where
.map_err(|_| PipelineError::DetatchedStreamReceiver)? .map_err(|_| PipelineError::DetatchedStreamReceiver)?
.map_err(PipelineError::ConnectionFailed)?; .map_err(PipelineError::ConnectionFailed)?;
let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx); // TODO: Detect end-of-stream using Server-Sent Events (SSE)
let mut is_complete_final = false;
let stream = stream.filter_map(|msg| async move { let stream = tokio_stream::StreamNotifyClose::new(
match serde_json::from_slice::<U>(&msg) { tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
Ok(r) => Some(r), )
Err(err) => { .filter_map(move |res| {
let json_str = String::from_utf8_lossy(&msg); if let Some(res_bytes) = res {
log::warn!(%err, %json_str, "Failed deserializing JSON to response"); if is_complete_final {
None return Some(U::from_err(
Error::msg(
"Response received after generation ended - this should never happen",
)
.into(),
));
}
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
Ok(item) => {
is_complete_final = item.complete_final;
if let Some(data) = item.data {
Some(data)
} else if is_complete_final {
None
} else {
Some(U::from_err(
Error::msg("Empty response received - this should never happen")
.into(),
))
}
}
Err(err) => {
// legacy log print
let json_str = String::from_utf8_lossy(&res_bytes);
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(Error::new(err).into()))
}
} }
} else if is_complete_final {
None
} else {
Some(U::from_err(
Error::msg("Stream ended before generation completed").into(),
))
} }
}); });
......
...@@ -13,6 +13,16 @@ ...@@ -13,6 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::{AsyncEngineContextProvider, ResponseStream};
use crate::{
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
pipeline::{
error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
},
protocols::maybe_error::MaybeError,
traits::DistributedRuntimeProvider,
};
use async_nats::client::{ use async_nats::client::{
RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders, RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
}; };
...@@ -27,15 +37,7 @@ use std::{ ...@@ -27,15 +37,7 @@ use std::{
Arc, Arc,
}, },
}; };
use tokio_stream::StreamExt;
use crate::{
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
pipeline::{
error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
},
traits::DistributedRuntimeProvider,
};
#[derive(Clone)] #[derive(Clone)]
pub struct PushRouter<T, U> pub struct PushRouter<T, U>
...@@ -94,7 +96,7 @@ async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPu ...@@ -94,7 +96,7 @@ async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPu
impl<T, U> PushRouter<T, U> impl<T, U> PushRouter<T, U>
where where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de> + MaybeError,
{ {
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> { pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?; let addressed = addressed_router(&client.endpoint).await?;
...@@ -109,51 +111,44 @@ where ...@@ -109,51 +111,44 @@ where
/// Issue a request to the next available instance in a round-robin fashion /// Issue a request to the next available instance in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let slf = self; let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let routing_algorithm = move || async move {
let counter = slf.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let instance_id = {
let instances = slf.client.instances_avail().await;
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}",
slf.client.endpoint.etcd_root()
));
}
let offset = counter % count as u64;
instances[offset as usize].id()
};
tracing::trace!("round robin router selected {instance_id}");
Ok(instance_id) let instance_id = {
let instances = self.client.instances_avail().await;
let count = instances.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no instances found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
let offset = counter % count as u64;
instances[offset as usize].id()
}; };
self.generate_with_fault_tolerance(routing_algorithm, request) tracing::trace!("round robin router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await .await
} }
/// Issue a request to a random endpoint /// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let slf = self; let instance_id = {
let routing_algorithm = move || async move { let instances = self.client.instances_avail().await;
let instance_id = { let count = instances.len();
let instances = slf.client.instances_avail().await; if count == 0 {
let count = instances.len(); return Err(anyhow::anyhow!(
if count == 0 { "no instances found for endpoint {:?}",
return Err(anyhow::anyhow!( self.client.endpoint.etcd_root()
"no instances found for endpoint {:?}", ));
slf.client.endpoint.etcd_root() }
)); let counter = rand::rng().random::<u64>();
} let offset = counter % count as u64;
let counter = rand::rng().random::<u64>(); instances[offset as usize].id()
let offset = counter % count as u64;
instances[offset as usize].id()
};
tracing::trace!("random router selected {instance_id}");
Ok(instance_id)
}; };
self.generate_with_fault_tolerance(routing_algorithm, request) tracing::trace!("random router selected {instance_id}");
self.generate_with_fault_detection(instance_id, request)
.await .await
} }
...@@ -163,22 +158,19 @@ where ...@@ -163,22 +158,19 @@ where
request: SingleIn<T>, request: SingleIn<T>,
instance_id: i64, instance_id: i64,
) -> anyhow::Result<ManyOut<U>> { ) -> anyhow::Result<ManyOut<U>> {
let slf = self; let found = {
let routing_algorithm = move || async move { let instances = self.client.instances_avail().await;
let found = { instances.iter().any(|ep| ep.id() == instance_id)
let instances = slf.client.instances_avail().await;
instances.iter().any(|ep| ep.id() == instance_id)
};
if !found {
return Err(anyhow::anyhow!(
"instance_id={instance_id} not found for endpoint {:?}",
slf.client.endpoint.etcd_root()
));
}
Ok(instance_id)
}; };
self.generate_with_fault_tolerance(routing_algorithm, request)
if !found {
return Err(anyhow::anyhow!(
"instance_id={instance_id} not found for endpoint {:?}",
self.client.endpoint.etcd_root()
));
}
self.generate_with_fault_detection(instance_id, request)
.await .await
} }
...@@ -190,29 +182,45 @@ where ...@@ -190,29 +182,45 @@ where
self.addressed.generate(request).await self.addressed.generate(request).await
} }
async fn generate_with_fault_tolerance<F, R>( async fn generate_with_fault_detection(
&self, &self,
routing_algorithm: F, instance_id: i64,
request: SingleIn<T>, request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> ) -> anyhow::Result<ManyOut<U>> {
where
F: FnOnce() -> R,
R: Future<Output = anyhow::Result<i64>>,
{
let instance_id = routing_algorithm().await?;
let subject = self.client.endpoint.subject_to(instance_id); let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
let stream = self.addressed.generate(request).await; let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
if let Some(err) = stream.as_ref().err() { match stream {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() { Ok(stream) => {
if matches!(req_err.kind(), NatsNoResponders) { let engine_ctx = stream.context();
self.client.report_instance_down(instance_id).await; let client = self.client.clone();
let stream = stream.then(move |res| {
let mut report_instance_down: Option<(Client, i64)> = None;
if let Some(err) = res.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
report_instance_down = Some((client.clone(), instance_id));
}
}
async move {
if let Some((client, instance_id)) = report_instance_down {
client.report_instance_down(instance_id).await;
}
res
}
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
Err(err) => {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) {
self.client.report_instance_down(instance_id).await;
}
} }
Err(err)
} }
} }
stream
} }
} }
...@@ -220,7 +228,7 @@ where ...@@ -220,7 +228,7 @@ where
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U> impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de> + MaybeError,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match self.client.instance_source.as_ref() { match self.client.instance_source.as_ref() {
......
...@@ -97,16 +97,37 @@ where ...@@ -97,16 +97,37 @@ where
let context = stream.context(); let context = stream.context();
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
let mut send_complete_final = true;
while let Some(resp) = stream.next().await { while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp); tracing::trace!("Sending response: {:?}", resp);
let resp_bytes = serde_json::to_vec(&resp) let resp_wrapper = NetworkStreamWrapper {
data: Some(resp),
complete_final: false,
};
let resp_bytes = serde_json::to_vec(&resp_wrapper)
.expect("fatal error: invalid response object - this should never happen"); .expect("fatal error: invalid response object - this should never happen");
if (publisher.send(resp_bytes.into()).await).is_err() { if (publisher.send(resp_bytes.into()).await).is_err() {
tracing::error!("Failed to publish response for stream {}", context.id()); tracing::error!("Failed to publish response for stream {}", context.id());
context.stop_generating(); context.stop_generating();
send_complete_final = false;
break; break;
} }
} }
if send_complete_final {
let resp_wrapper = NetworkStreamWrapper::<U> {
data: None,
complete_final: true,
};
let resp_bytes = serde_json::to_vec(&resp_wrapper)
.expect("fatal error: invalid response object - this should never happen");
if (publisher.send(resp_bytes.into()).await).is_err() {
tracing::error!(
"Failed to publish complete final for stream {}",
context.id()
);
}
}
Ok(()) Ok(())
} }
......
...@@ -19,6 +19,7 @@ use std::str::FromStr; ...@@ -19,6 +19,7 @@ use std::str::FromStr;
use crate::pipeline::PipelineError; use crate::pipeline::PipelineError;
pub mod annotated; pub mod annotated;
pub mod maybe_error;
pub type LeaseId = i64; pub type LeaseId = i64;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
use super::*; use super::*;
use crate::{error, Result}; use crate::{error, Result};
use maybe_error::MaybeError;
pub trait AnnotationsProvider { pub trait AnnotationsProvider {
fn annotations(&self) -> Option<Vec<String>>; fn annotations(&self) -> Option<Vec<String>>;
...@@ -28,7 +29,7 @@ pub trait AnnotationsProvider { ...@@ -28,7 +29,7 @@ pub trait AnnotationsProvider {
/// Our services have the option of returning an "annotated" stream, which allows use /// Our services have the option of returning an "annotated" stream, which allows use
/// to include additional information with each delta. This is useful for debugging, /// to include additional information with each delta. This is useful for debugging,
/// performance benchmarking, and improved observability. /// performance benchmarking, and improved observability.
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Annotated<R> { pub struct Annotated<R> {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<R>, pub data: Option<R>,
...@@ -146,6 +147,28 @@ impl<R> Annotated<R> { ...@@ -146,6 +147,28 @@ impl<R> Annotated<R> {
} }
} }
impl<R> MaybeError for Annotated<R>
where
R: for<'de> Deserialize<'de> + Serialize,
{
fn from_err(err: Box<dyn std::error::Error>) -> Self {
Annotated::from_error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
if self.is_error() {
if let Some(comment) = &self.comment {
if !comment.is_empty() {
return Some(anyhow::Error::msg(comment.join("; ")).into());
}
}
Some(anyhow::Error::msg("unknown error").into())
} else {
None
}
}
}
// impl<R> Annotated<R> // impl<R> Annotated<R>
// where // where
// R: for<'de> Deserialize<'de> + Serialize, // R: for<'de> Deserialize<'de> + Serialize,
...@@ -166,3 +189,27 @@ impl<R> Annotated<R> { ...@@ -166,3 +189,27 @@ impl<R> Annotated<R> {
// Box::pin(stream) // Box::pin(stream)
// } // }
// } // }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_maybe_error() {
let annotated = Annotated::from_data("Test data".to_string());
assert!(annotated.err().is_none());
assert!(annotated.is_ok());
assert!(!annotated.is_err());
let annotated = Annotated::<String>::from_error("Test error 2".to_string());
assert_eq!(format!("{}", annotated.err().unwrap()), "Test error 2");
assert!(!annotated.is_ok());
assert!(annotated.is_err());
let annotated =
Annotated::<String>::from_err(anyhow::Error::msg("Test error 3".to_string()).into());
assert_eq!(format!("{}", annotated.err().unwrap()), "Test error 3");
assert!(!annotated.is_ok());
assert!(annotated.is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::error::Error;
pub trait MaybeError {
/// Construct an instance from an error.
fn from_err(err: Box<dyn Error>) -> Self;
/// Construct into an error instance.
fn err(&self) -> Option<Box<dyn Error>>;
/// Check if the current instance represents a success.
fn is_ok(&self) -> bool {
!self.is_err()
}
/// Check if the current instance represents an error.
fn is_err(&self) -> bool {
self.err().is_some()
}
}
#[cfg(test)]
mod tests {
use super::*;
struct TestError {
message: String,
}
impl MaybeError for TestError {
fn from_err(err: Box<dyn Error>) -> Self {
TestError {
message: err.to_string(),
}
}
fn err(&self) -> Option<Box<dyn Error>> {
Some(anyhow::Error::msg(self.message.clone()).into())
}
}
#[test]
fn test_maybe_error_default_implementations() {
let err = TestError::from_err(anyhow::Error::msg("Test error".to_string()).into());
assert_eq!(format!("{}", err.err().unwrap()), "Test error");
assert!(!err.is_ok());
assert!(err.is_err());
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment