Unverified Commit c8276cd2 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Standardized Dynamo Error Type (#6303)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent eb76a8b5
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Dynamo Error System
//!
//! This module provides a standardized error type for Dynamo with support for:
//! - Categorized error types via [`ErrorType`] enum
//! - Error chaining via the standard [`std::error::Error::source()`] method
//! - Serialization for network transmission via serde
//!
//! # DynamoError
//!
//! [`DynamoError`] is the standardized error type for Dynamo. It can be created
//! directly or converted from any [`std::error::Error`]:
//!
//! ```rust,ignore
//! use dynamo_runtime::error::{DynamoError, ErrorType};
//!
//! // Simple error
//! let err = DynamoError::msg("something failed");
//!
//! // Typed error with cause
//! let cause = std::io::Error::other("io error");
//! let err = DynamoError::builder()
//! .error_type(ErrorType::Unknown)
//! .message("operation failed")
//! .cause(cause)
//! .build();
//!
//! // Convert from any std::error::Error
//! let std_err = std::io::Error::other("io error");
//! let dynamo_err = DynamoError::from(Box::new(std_err) as Box<dyn std::error::Error>);
//! ```
use serde::{Deserialize, Serialize};
use std::fmt;
// ============================================================================
// ErrorType Enum
// ============================================================================
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ErrorType {
/// Uncategorized or unknown error.
Unknown,
/// Failed to establish a connection to a remote worker.
CannotConnect,
/// An established connection was lost unexpectedly.
Disconnected,
/// A connection or request timed out.
ConnectionTimeout,
/// Error originating from a backend engine.
Backend(BackendError),
}
impl fmt::Display for ErrorType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorType::Unknown => write!(f, "Unknown"),
ErrorType::CannotConnect => write!(f, "CannotConnect"),
ErrorType::Disconnected => write!(f, "Disconnected"),
ErrorType::ConnectionTimeout => write!(f, "ConnectionTimeout"),
ErrorType::Backend(sub) => write!(f, "Backend.{sub}"),
}
}
}
/// Categorizes errors into a fixed set of standard types.
///
/// Consumers (e.g., the migration module) inspect the error type to decide
/// what action to take, rather than the error defining its own behavior.
/// Backend engine error subcategories.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendError {
/// The engine process has shut down or crashed.
EngineShutdown,
}
impl fmt::Display for BackendError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BackendError::EngineShutdown => write!(f, "EngineShutdown"),
}
}
}
// ============================================================================
// DynamoError - The Standardized Error Type
// ============================================================================
/// The standardized error type for Dynamo.
///
/// `DynamoError` is a serializable, chainable error that:
/// - Carries an [`ErrorType`] for categorization
/// - Supports error chaining via [`std::error::Error::source()`]
/// - Is serializable for network transmission via `Annotated`
/// - Can be created from any [`std::error::Error`]
///
/// # Display
///
/// `Display` shows only the current error (standard Rust convention).
/// Use `source()` to walk the cause chain:
///
/// ```rust,ignore
/// let err = DynamoError::msg("outer");
/// println!("{}", err); // "Unknown: outer"
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamoError {
error_type: ErrorType,
message: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
caused_by: Option<Box<DynamoError>>,
}
impl DynamoError {
/// Create a builder for constructing a `DynamoError`.
pub fn builder() -> DynamoErrorBuilder {
DynamoErrorBuilder::default()
}
/// Shorthand to create an `Unknown` error with just a message and no cause.
pub fn msg(message: impl Into<String>) -> Self {
Self::builder().message(message).build()
}
/// Returns the error type.
pub fn error_type(&self) -> ErrorType {
self.error_type
}
/// Returns the error message.
pub fn message(&self) -> &str {
&self.message
}
}
impl fmt::Display for DynamoError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {}", self.error_type, self.message)
}
}
impl std::error::Error for DynamoError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.caused_by
.as_deref()
.map(|e| e as &(dyn std::error::Error + 'static))
}
}
/// Convert from a reference to any `std::error::Error`.
///
/// If the error is already a `DynamoError`, it is cloned. Otherwise, it is
/// wrapped as `ErrorType::Unknown` with the display string as the message.
/// The source chain is recursively converted, preserving `DynamoError` instances.
impl<'a> From<&'a (dyn std::error::Error + 'static)> for DynamoError {
fn from(err: &'a (dyn std::error::Error + 'static)) -> Self {
if let Some(dynamo_err) = err.downcast_ref::<DynamoError>() {
return dynamo_err.clone();
}
Self {
error_type: ErrorType::Unknown,
message: err.to_string(),
caused_by: err.source().map(|s| Box::new(DynamoError::from(s))),
}
}
}
/// Convert from an owned boxed `std::error::Error`.
///
/// If the error is already a `DynamoError`, ownership is taken without cloning.
/// Otherwise, falls back to the reference-based conversion.
impl From<Box<dyn std::error::Error + 'static>> for DynamoError {
fn from(err: Box<dyn std::error::Error + 'static>) -> Self {
match err.downcast::<DynamoError>() {
Ok(dynamo_err) => *dynamo_err,
Err(err) => DynamoError::from(&*err as &(dyn std::error::Error + 'static)),
}
}
}
// ============================================================================
// DynamoErrorBuilder
// ============================================================================
/// Builder for constructing a [`DynamoError`].
///
/// # Example
/// ```rust,ignore
/// let err = DynamoError::builder()
/// .error_type(ErrorType::Disconnected)
/// .message("worker lost")
/// .cause(some_io_error)
/// .build();
/// ```
#[derive(Default)]
pub struct DynamoErrorBuilder {
error_type: Option<ErrorType>,
message: Option<String>,
caused_by: Option<Box<DynamoError>>,
}
impl DynamoErrorBuilder {
/// Set the error type.
pub fn error_type(mut self, error_type: ErrorType) -> Self {
self.error_type = Some(error_type);
self
}
/// Set the error message.
pub fn message(mut self, message: impl Into<String>) -> Self {
self.message = Some(message.into());
self
}
/// Set the cause from any `std::error::Error`.
///
/// If the cause is already a `DynamoError`, it is preserved as-is.
/// Otherwise, it is converted to a `DynamoError` with `ErrorType::Unknown`.
pub fn cause(mut self, cause: impl std::error::Error + 'static) -> Self {
self.caused_by = Some(Box::new(DynamoError::from(
&cause as &(dyn std::error::Error + 'static),
)));
self
}
/// Build the `DynamoError`.
///
/// Defaults: `error_type` → `Unknown`, `message` → `""`, `cause` → `None`.
pub fn build(self) -> DynamoError {
DynamoError {
error_type: self.error_type.unwrap_or(ErrorType::Unknown),
message: self.message.unwrap_or_default(),
caused_by: self.caused_by,
}
}
}
// ============================================================================
// Utility Functions
// ============================================================================
/// Check whether an error chain contains a specific set of error types
/// while not containing any of the excluded error types.
///
/// Walks the chain via `source()`, inspecting each error that can be downcast
/// to `DynamoError`. Returns `false` immediately if any error's type is in
/// `exclude_set`. Otherwise, returns `true` if at least one error's type is
/// in `match_set`. Errors that are not `DynamoError` are skipped.
pub fn match_error_chain(
err: &(dyn std::error::Error + 'static),
match_set: &[ErrorType],
exclude_set: &[ErrorType],
) -> bool {
let mut found = false;
let mut current: Option<&(dyn std::error::Error + 'static)> = Some(err);
while let Some(e) = current {
if let Some(dynamo_err) = e.downcast_ref::<DynamoError>() {
if exclude_set.contains(&dynamo_err.error_type()) {
return false;
}
if match_set.contains(&dynamo_err.error_type()) {
found = true;
}
}
current = e.source();
}
found
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use std::error::Error;
// Compile-time assertions that DynamoError is std::error::Error + Send + Sync + 'static.
// These fail at compile time if a future change breaks these guarantees.
const _: () = {
fn assert_stderror<T: std::error::Error>() {}
fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}
fn assert_static<T: 'static>() {}
fn assert_all() {
assert_stderror::<DynamoError>();
assert_send::<DynamoError>();
assert_sync::<DynamoError>();
assert_static::<DynamoError>();
}
};
#[test]
fn test_msg_constructor() {
let err = DynamoError::msg("something failed");
assert_eq!(err.error_type(), ErrorType::Unknown);
assert_eq!(err.message(), "something failed");
assert!(err.source().is_none());
}
#[test]
fn test_new_constructor_with_cause() {
let cause = std::io::Error::other("io error");
let err = DynamoError::builder()
.error_type(ErrorType::Unknown)
.message("operation failed")
.cause(cause)
.build();
assert_eq!(err.error_type(), ErrorType::Unknown);
assert_eq!(err.message(), "operation failed");
assert!(err.source().is_some());
}
#[test]
fn test_display_shows_only_current_error() {
let cause = std::io::Error::other("io error");
let err = DynamoError::builder()
.error_type(ErrorType::Unknown)
.message("operation failed")
.cause(cause)
.build();
// Display should only show the current error, not the chain
assert_eq!(err.to_string(), "Unknown: operation failed");
}
#[test]
fn test_source_chain() {
let cause = std::io::Error::other("io error");
let err = DynamoError::builder()
.error_type(ErrorType::Unknown)
.message("operation failed")
.cause(cause)
.build();
// source() should return the cause
let source = err.source().unwrap();
assert!(source.to_string().contains("io error"));
}
#[test]
fn test_from_boxed_std_error() {
let std_err = std::io::Error::other("io error");
let boxed: Box<dyn std::error::Error> = Box::new(std_err);
let dynamo_err = DynamoError::from(boxed);
assert_eq!(dynamo_err.error_type(), ErrorType::Unknown);
assert_eq!(dynamo_err.message(), "io error");
}
#[test]
fn test_from_boxed_takes_ownership_of_dynamo_error() {
let inner = DynamoError::msg("original");
let boxed: Box<dyn std::error::Error> = Box::new(inner);
let dynamo_err = DynamoError::from(boxed);
// Should take ownership, not clone or wrap
assert_eq!(dynamo_err.error_type(), ErrorType::Unknown);
assert_eq!(dynamo_err.message(), "original");
}
#[test]
fn test_from_boxed_with_source_chain() {
#[derive(Debug)]
struct OuterError {
source: std::io::Error,
}
impl fmt::Display for OuterError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "outer error occurred")
}
}
impl std::error::Error for OuterError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.source)
}
}
let inner = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
let outer = OuterError { source: inner };
let boxed: Box<dyn std::error::Error> = Box::new(outer);
let dynamo_err = DynamoError::from(boxed);
assert_eq!(dynamo_err.message(), "outer error occurred");
assert!(dynamo_err.source().is_some());
let cause = dynamo_err.source().unwrap();
assert!(cause.to_string().contains("file not found"));
}
#[test]
fn test_serialization_roundtrip() {
let cause = DynamoError::msg("inner cause");
let err = DynamoError::builder()
.error_type(ErrorType::Unknown)
.message("outer error")
.cause(cause)
.build();
let json = serde_json::to_string(&err).unwrap();
let deserialized: DynamoError = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.error_type(), ErrorType::Unknown);
assert_eq!(deserialized.message(), "outer error");
assert!(deserialized.source().is_some());
let cause = deserialized
.source()
.unwrap()
.downcast_ref::<DynamoError>()
.unwrap();
assert_eq!(cause.message(), "inner cause");
}
#[test]
fn test_error_type_display() {
assert_eq!(ErrorType::Unknown.to_string(), "Unknown");
}
}
...@@ -25,6 +25,7 @@ pub mod compute; ...@@ -25,6 +25,7 @@ pub mod compute;
pub mod discovery; pub mod discovery;
pub mod engine; pub mod engine;
pub mod engine_routes; pub mod engine_routes;
pub mod error;
pub mod health_check; pub mod health_check;
pub mod local_endpoint_registry; pub mod local_endpoint_registry;
pub mod system_status_server; pub mod system_status_server;
......
...@@ -33,9 +33,6 @@ use super::{ ...@@ -33,9 +33,6 @@ use super::{
}; };
use ingress::push_handler::WorkHandlerMetrics; use ingress::push_handler::WorkHandlerMetrics;
// Define stream error message constant
pub const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
// Add Prometheus metrics types // Add Prometheus metrics types
use crate::metrics::MetricsHierarchy; use crate::metrics::MetricsHierarchy;
use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge}; use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
......
...@@ -6,11 +6,11 @@ use std::sync::Arc; ...@@ -6,11 +6,11 @@ use std::sync::Arc;
use super::unified_client::RequestPlaneClient; use super::unified_client::RequestPlaneClient;
use super::*; use super::*;
use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data}; use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
use crate::error::{DynamoError, ErrorType};
use crate::logging::inject_trace_headers_into_map; use crate::logging::inject_trace_headers_into_map;
use crate::pipeline::network::ConnectionInfo; use crate::pipeline::network::ConnectionInfo;
use crate::pipeline::network::NetworkStreamWrapper; use crate::pipeline::network::NetworkStreamWrapper;
use crate::pipeline::network::PendingConnections; use crate::pipeline::network::PendingConnections;
use crate::pipeline::network::STREAM_ERR_MSG;
use crate::pipeline::network::StreamOptions; use crate::pipeline::network::StreamOptions;
use crate::pipeline::network::TwoPartCodec; use crate::pipeline::network::TwoPartCodec;
use crate::pipeline::network::codec::TwoPartMessage; use crate::pipeline::network::codec::TwoPartMessage;
...@@ -187,12 +187,10 @@ where ...@@ -187,12 +187,10 @@ where
.filter_map(move |res| { .filter_map(move |res| {
if let Some(res_bytes) = res { if let Some(res_bytes) = res {
if is_complete_final { if is_complete_final {
return Some(U::from_err( let err = DynamoError::msg(
Error::msg( "Response received after generation ended - this should never happen",
"Response received after generation ended - this should never happen", );
) return Some(U::from_err(err));
.into(),
));
} }
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) { match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
Ok(item) => { Ok(item) => {
...@@ -202,10 +200,10 @@ where ...@@ -202,10 +200,10 @@ where
} else if is_complete_final { } else if is_complete_final {
None None
} else { } else {
Some(U::from_err( let err = DynamoError::msg(
Error::msg("Empty response received - this should never happen") "Empty response received - this should never happen",
.into(), );
)) Some(U::from_err(err))
} }
} }
Err(err) => { Err(err) => {
...@@ -213,7 +211,7 @@ where ...@@ -213,7 +211,7 @@ where
let json_str = String::from_utf8_lossy(&res_bytes); let json_str = String::from_utf8_lossy(&res_bytes);
tracing::warn!(%err, %json_str, "Failed deserializing JSON to response"); tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(Error::new(err).into())) Some(U::from_err(DynamoError::msg(err.to_string())))
} }
} }
} else if is_complete_final { } else if is_complete_final {
...@@ -227,8 +225,12 @@ where ...@@ -227,8 +225,12 @@ where
None None
} else { } else {
// stream ended unexpectedly // stream ended unexpectedly
tracing::debug!("{STREAM_ERR_MSG}"); let err = DynamoError::builder()
Some(U::from_err(Error::msg(STREAM_ERR_MSG).into())) .error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build();
tracing::debug!("{}", err);
Some(U::from_err(err))
} }
}); });
......
...@@ -179,7 +179,15 @@ impl RequestPlaneClient for HttpRequestClient { ...@@ -179,7 +179,15 @@ impl RequestPlaneClient for HttpRequestClient {
req = req.header(key, value); req = req.header(key, value);
} }
let response = req.send().await?; let response = req.send().await.map_err(|e| {
anyhow::anyhow!(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::CannotConnect)
.message(format!("HTTP request to {address} failed"))
.cause(e)
.build()
)
})?;
if !response.status().is_success() { if !response.status().is_success() {
anyhow::bail!( anyhow::bail!(
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
//! providing a consistent interface across all transport types. //! providing a consistent interface across all transport types.
use super::unified_client::{ClientStats, Headers, RequestPlaneClient}; use super::unified_client::{ClientStats, Headers, RequestPlaneClient};
use crate::error::{DynamoError, ErrorType};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
...@@ -47,9 +48,17 @@ impl RequestPlaneClient for NatsRequestClient { ...@@ -47,9 +48,17 @@ impl RequestPlaneClient for NatsRequestClient {
// Send request with headers // Send request with headers
let response = self let response = self
.client .client
.request_with_headers(address, nats_headers, payload) .request_with_headers(address.clone(), nats_headers, payload)
.await .await
.map_err(|e| anyhow::anyhow!("NATS request failed: {}", e))?; .map_err(|e| {
anyhow::anyhow!(
DynamoError::builder()
.error_type(ErrorType::CannotConnect)
.message(format!("NATS request to {address} failed"))
.cause(e)
.build()
)
})?;
Ok(response.payload) Ok(response.payload)
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream, STREAM_ERR_MSG}; use super::{AsyncEngineContextProvider, ResponseStream};
use crate::error::{BackendError, ErrorType, match_error_chain};
/// Check if an error chain indicates the worker should be reported as down.
fn is_inhibited(err: &(dyn std::error::Error + 'static)) -> bool {
const INHIBITED: &[ErrorType] = &[
ErrorType::CannotConnect,
ErrorType::Disconnected,
ErrorType::ConnectionTimeout,
ErrorType::Backend(BackendError::EngineShutdown),
];
match_error_chain(err, INHIBITED, &[])
}
use crate::{ use crate::{
component::{Client, Endpoint}, component::{Client, Endpoint},
engine::{AsyncEngine, Data}, engine::{AsyncEngine, Data},
...@@ -12,9 +24,6 @@ use crate::{ ...@@ -12,9 +24,6 @@ use crate::{
protocols::maybe_error::MaybeError, protocols::maybe_error::MaybeError,
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
}; };
use async_nats::client::{
RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
};
use async_trait::async_trait; use async_trait::async_trait;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -399,12 +408,12 @@ where ...@@ -399,12 +408,12 @@ where
let engine_ctx = stream.context(); let engine_ctx = stream.context();
let client = self.client.clone(); let client = self.client.clone();
let stream = stream.map(move |res| { let stream = stream.map(move |res| {
// TODO: Standardize error type to avoid using string matching DIS-364 // Check if the error is migratable (indicates worker/connection failure)
if let Some(err) = res.err() if let Some(err) = res.err()
&& format!("{:?}", err) == STREAM_ERR_MSG && is_inhibited(&err)
{ {
tracing::debug!( tracing::debug!(
"Reporting instance {instance_id} down due to stream error: {err}" "Reporting instance {instance_id} down due to migratable error: {err}"
); );
client.report_instance_down(instance_id); client.report_instance_down(instance_id);
} }
...@@ -413,13 +422,8 @@ where ...@@ -413,13 +422,8 @@ where
Ok(ResponseStream::new(Box::pin(stream), engine_ctx)) Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
} }
Err(err) => { Err(err) => {
if self.fault_detection_enabled if self.fault_detection_enabled && is_inhibited(err.as_ref()) {
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>() tracing::debug!("Reporting instance {instance_id} down due to error: {err}");
&& matches!(req_err.kind(), NatsNoResponders)
{
tracing::debug!(
"Reporting instance {instance_id} down due to request error: {req_err}"
);
self.client.report_instance_down(instance_id); self.client.report_instance_down(instance_id);
} }
Err(err) Err(err)
......
...@@ -545,13 +545,27 @@ impl RequestPlaneClient for TcpRequestClient { ...@@ -545,13 +545,27 @@ impl RequestPlaneClient for TcpRequestClient {
self.stats.errors.fetch_add(1, Ordering::Relaxed); self.stats.errors.fetch_add(1, Ordering::Relaxed);
tracing::warn!("TCP request failed to {}: {}", addr, e); tracing::warn!("TCP request failed to {}: {}", addr, e);
// Don't return unhealthy connection to pool, let it drop // Don't return unhealthy connection to pool, let it drop
Err(e) let cause = crate::error::DynamoError::from(
e.into_boxed_dyn_error() as Box<dyn std::error::Error + 'static>
);
Err(anyhow::anyhow!(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::CannotConnect)
.message(format!("TCP request to {addr} failed"))
.cause(cause)
.build()
))
} }
Err(_) => { Err(_) => {
self.stats.errors.fetch_add(1, Ordering::Relaxed); self.stats.errors.fetch_add(1, Ordering::Relaxed);
tracing::warn!("TCP request timeout to {}", addr); tracing::warn!("TCP request timeout to {}", addr);
// Don't return timed-out connection to pool // Don't return timed-out connection to pool
Err(anyhow::anyhow!("TCP request timeout to {}", addr)) Err(anyhow::anyhow!(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::CannotConnect)
.message(format!("TCP request to {addr} timed out"))
.build()
))
} }
} }
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use crate::metrics::prometheus_names::work_handler; use crate::metrics::prometheus_names::work_handler;
use crate::protocols::maybe_error::MaybeError; use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge}; use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
...@@ -265,13 +266,6 @@ where ...@@ -265,13 +266,6 @@ where
let mut send_complete_final = true; let mut send_complete_final = true;
while let Some(resp) = stream.next().await { while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp); tracing::trace!("Sending response: {:?}", resp);
if let Some(err) = resp.err()
&& format!("{:?}", err) == STREAM_ERR_MSG
{
tracing::warn!(STREAM_ERR_MSG);
send_complete_final = false;
break;
}
let resp_wrapper = NetworkStreamWrapper { let resp_wrapper = NetworkStreamWrapper {
data: Some(resp), data: Some(resp),
complete_final: false, complete_final: false,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::maybe_error::MaybeError; use super::maybe_error::MaybeError;
use crate::error::DynamoError;
use anyhow::{Result, anyhow as error}; use anyhow::{Result, anyhow as error};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -27,16 +28,19 @@ pub struct Annotated<R> { ...@@ -27,16 +28,19 @@ pub struct Annotated<R> {
pub event: Option<String>, pub event: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub comment: Option<Vec<String>>, pub comment: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<DynamoError>,
} }
impl<R> Annotated<R> { impl<R> Annotated<R> {
/// Create a new annotated stream from the given error /// Create a new annotated stream from the given error string
pub fn from_error(error: String) -> Self { pub fn from_error(error: String) -> Self {
Self { Self {
data: None, data: None,
id: None, id: None,
event: Some("error".to_string()), event: Some("error".to_string()),
comment: Some(vec![error]), comment: None,
error: Some(DynamoError::msg(error)),
} }
} }
...@@ -47,6 +51,7 @@ impl<R> Annotated<R> { ...@@ -47,6 +51,7 @@ impl<R> Annotated<R> {
id: None, id: None,
event: None, event: None,
comment: None, comment: None,
error: None,
} }
} }
...@@ -62,15 +67,20 @@ impl<R> Annotated<R> { ...@@ -62,15 +67,20 @@ impl<R> Annotated<R> {
id: None, id: None,
event: Some(name.into()), event: Some(name.into()),
comment: Some(vec![serde_json::to_string(value)?]), comment: Some(vec![serde_json::to_string(value)?]),
error: None,
}) })
} }
/// Convert to a [`Result<Self, String>`] /// Convert to a [`Result<Self, String>`]
/// If [`Self::event`] is "error", return an error message(s) held by [`Self::comment`] /// If [`Self::event`] is "error", return an error message
pub fn ok(self) -> Result<Self, String> { pub fn ok(self) -> Result<Self, String> {
if let Some(event) = &self.event if let Some(event) = &self.event
&& event == "error" && event == "error"
{ {
// First check DynamoError, then fallback to comment
if let Some(ref err) = self.error {
return Err(err.to_string());
}
return Err(self return Err(self
.comment .comment
.unwrap_or(vec!["unknown error".to_string()]) .unwrap_or(vec!["unknown error".to_string()])
...@@ -97,6 +107,7 @@ impl<R> Annotated<R> { ...@@ -97,6 +107,7 @@ impl<R> Annotated<R> {
id: self.id, id: self.id,
event: self.event, event: self.event,
comment: self.comment, comment: self.comment,
error: self.error,
} }
} }
...@@ -112,6 +123,7 @@ impl<R> Annotated<R> { ...@@ -112,6 +123,7 @@ impl<R> Annotated<R> {
id: self.id, id: self.id,
event: self.event, event: self.event,
comment: self.comment, comment: self.comment,
error: self.error,
}, },
Err(e) => Annotated::from_error(e), Err(e) => Annotated::from_error(e),
} }
...@@ -125,11 +137,18 @@ impl<R> Annotated<R> { ...@@ -125,11 +137,18 @@ impl<R> Annotated<R> {
match self.data { match self.data {
Some(data) => Ok(Some(data)), Some(data) => Ok(Some(data)),
None => match self.event { None => match self.event {
Some(event) if event == "error" => Err(error!( Some(event) if event == "error" => {
self.comment // First check DynamoError, then fallback to comment
.unwrap_or(vec!["unknown error".to_string()]) if let Some(ref err) = self.error {
.join(", ") Err(error!("{}", err))?
))?, } else {
Err(error!(
self.comment
.unwrap_or(vec!["unknown error".to_string()])
.join(", ")
))?
}
}
_ => Ok(None), _ => Ok(None),
}, },
} }
...@@ -138,47 +157,40 @@ impl<R> Annotated<R> { ...@@ -138,47 +157,40 @@ impl<R> Annotated<R> {
impl<R> MaybeError for Annotated<R> impl<R> MaybeError for Annotated<R>
where where
R: for<'de> Deserialize<'de> + Serialize, R: for<'de> Deserialize<'de>,
{ {
fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self { fn from_err(err: impl std::error::Error + 'static) -> Self {
Annotated::from_error(format!("{:?}", err)) Self {
data: None,
id: None,
event: Some("error".to_string()),
comment: None,
error: Some(DynamoError::from(
Box::new(err) as Box<dyn std::error::Error + 'static>
)),
}
} }
fn err(&self) -> Option<anyhow::Error> { fn err(&self) -> Option<DynamoError> {
if self.is_error() { if self.is_error() {
// First check DynamoError field
if let Some(ref error) = self.error {
return Some(error.clone());
}
// Fallback to comment-based error
if let Some(comment) = &self.comment if let Some(comment) = &self.comment
&& !comment.is_empty() && !comment.is_empty()
{ {
return Some(anyhow::Error::msg(comment.join("; "))); return Some(DynamoError::msg(comment.join("; ")));
} }
Some(anyhow::Error::msg("unknown error")) Some(DynamoError::msg("unknown error"))
} else { } else {
None None
} }
} }
} }
// impl<R> Annotated<R>
// where
// R: for<'de> Deserialize<'de> + Serialize,
// {
// pub fn convert_sse_stream(
// stream: DataStream<Result<Message, SseCodecError>>,
// ) -> DataStream<Annotated<R>> {
// let stream = stream.map(|message| match message {
// Ok(message) => {
// let delta = Annotated::<R>::try_from(message);
// match delta {
// Ok(delta) => delta,
// Err(e) => Annotated::from_error(e.to_string()),
// }
// }
// Err(e) => Annotated::from_error(e.to_string()),
// });
// Box::pin(stream)
// }
// }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -190,11 +202,69 @@ mod tests { ...@@ -190,11 +202,69 @@ mod tests {
assert!(annotated.is_ok()); assert!(annotated.is_ok());
let annotated = Annotated::<String>::from_error("Test error 2".to_string()); let annotated = Annotated::<String>::from_error("Test error 2".to_string());
assert_eq!(format!("{}", annotated.err().unwrap()), "Test error 2"); assert!(annotated.err().is_some());
assert!(annotated.is_err());
let dynamo_err = DynamoError::msg("Test error 3");
let annotated = Annotated::<String>::from_err(dynamo_err);
assert!(annotated.is_err()); assert!(annotated.is_err());
}
#[test]
fn test_from_err() {
let err = DynamoError::msg("connection lost");
let annotated = Annotated::<String>::from_err(err);
let annotated =
Annotated::<String>::from_err(anyhow::Error::msg("Test error 3".to_string()).into());
assert!(annotated.is_err()); assert!(annotated.is_err());
let err = annotated.err().unwrap();
assert!(err.to_string().contains("connection lost"));
}
#[test]
fn test_error_serialization() {
let err = DynamoError::msg("test error");
let annotated = Annotated::<String>::from_err(err);
// Serialize and deserialize
let json = serde_json::to_string(&annotated).unwrap();
let deserialized: Annotated<String> = serde_json::from_str(&json).unwrap();
assert!(deserialized.is_err());
assert!(
deserialized
.err()
.unwrap()
.to_string()
.contains("test error")
);
}
#[test]
fn test_transfer_preserves_error() {
let err = DynamoError::msg("request timed out");
let annotated = Annotated::<String>::from_err(err);
let transferred: Annotated<i32> = annotated.transfer(None);
assert!(transferred.err().is_some());
}
#[test]
fn test_ok_method() {
let err = DynamoError::msg("connection lost");
let annotated = Annotated::<String>::from_err(err);
let result = annotated.ok();
assert!(result.is_err());
assert!(result.unwrap_err().contains("connection lost"));
}
#[test]
fn test_into_result() {
let err = DynamoError::msg("connection lost");
let annotated = Annotated::<String>::from_err(err);
let result = annotated.into_result();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("connection lost"));
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::error::Error; //! MaybeError trait for types that may contain error information.
//!
//! This module provides the `MaybeError` trait which allows types to represent
//! either successful data or error states. It integrates with the `DynamoError`
//! system to provide structured error handling.
use crate::error::DynamoError;
/// A trait for types that may contain error information.
///
/// This trait allows a type to represent either a successful value or an error state.
/// It integrates with `DynamoError` for structured error information.
///
/// # Example
///
/// ```rust,ignore
/// use dynamo_runtime::protocols::maybe_error::MaybeError;
/// use dynamo_runtime::error::DynamoError;
///
/// struct MyResponse {
/// data: Option<String>,
/// error: Option<DynamoError>,
/// }
///
/// impl MaybeError for MyResponse {
/// fn from_err(err: impl std::error::Error + 'static) -> Self {
/// MyResponse {
/// data: None,
/// error: Some(DynamoError::from(
/// Box::new(err) as Box<dyn std::error::Error + 'static>
/// )),
/// }
/// }
///
/// fn err(&self) -> Option<DynamoError> {
/// self.error.clone()
/// }
/// }
/// ```
pub trait MaybeError { pub trait MaybeError {
/// Construct an instance from an error. /// Construct an instance from an error.
fn from_err(err: Box<dyn Error + Send + Sync>) -> Self; ///
/// The error is converted to a `DynamoError` for serialization.
fn from_err(err: impl std::error::Error + 'static) -> Self;
/// Construct into an error instance. /// Get the error as a `DynamoError` if this represents an error state.
fn err(&self) -> Option<anyhow::Error>; ///
/// Returns `Some(DynamoError)` if this instance represents an error, `None` otherwise.
fn err(&self) -> Option<DynamoError>;
/// Check if the current instance represents a success. /// Check if the current instance represents a success.
fn is_ok(&self) -> bool { fn is_ok(&self) -> bool {
...@@ -26,24 +67,46 @@ mod tests { ...@@ -26,24 +67,46 @@ mod tests {
use super::*; use super::*;
struct TestError { struct TestError {
message: String, error: Option<DynamoError>,
} }
impl MaybeError for TestError { impl MaybeError for TestError {
fn from_err(err: Box<dyn Error + Send + Sync>) -> Self { fn from_err(err: impl std::error::Error + 'static) -> Self {
TestError { TestError {
message: err.to_string(), error: Some(DynamoError::from(
Box::new(err) as Box<dyn std::error::Error + 'static>
)),
} }
} }
fn err(&self) -> Option<anyhow::Error> {
Some(anyhow::Error::msg(self.message.clone())) fn err(&self) -> Option<DynamoError> {
self.error.clone()
} }
} }
#[test] #[test]
fn test_maybe_error_default_implementations() { fn test_maybe_error_default_implementations() {
let err = TestError::from_err(anyhow::Error::msg("Test error".to_string()).into()); let dynamo_err = DynamoError::msg("Test error");
assert_eq!(format!("{}", err.err().unwrap()), "Test error"); let err = TestError::from_err(dynamo_err);
assert!(err.err().unwrap().to_string().contains("Test error"));
assert!(!err.is_ok()); assert!(!err.is_ok());
assert!(err.is_err()); assert!(err.is_err());
} }
#[test]
fn test_from_std_error() {
let std_err = std::io::Error::other("io failure");
let test_err = TestError::from_err(std_err);
assert!(test_err.is_err());
assert!(test_err.err().unwrap().to_string().contains("io failure"));
}
#[test]
fn test_not_error() {
let test = TestError { error: None };
assert!(test.is_ok());
assert!(!test.is_err());
assert!(test.err().is_none());
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment