Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b4ddca99
Unverified
Commit
b4ddca99
authored
Jul 07, 2025
by
Jacky
Committed by
GitHub
Jul 07, 2025
Browse files
feat: Failure Detection while Responses are returning (#1671)
parent
bd91dca6
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
361 additions
and
111 deletions
+361
-111
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+9
-20
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+38
-0
lib/runtime/src/pipeline/network.rs
lib/runtime/src/pipeline/network.rs
+51
-0
lib/runtime/src/pipeline/network/egress/addressed_router.rs
lib/runtime/src/pipeline/network/egress/addressed_router.rs
+45
-11
lib/runtime/src/pipeline/network/egress/push_router.rs
lib/runtime/src/pipeline/network/egress/push_router.rs
+86
-78
lib/runtime/src/pipeline/network/ingress/push_handler.rs
lib/runtime/src/pipeline/network/ingress/push_handler.rs
+22
-1
lib/runtime/src/protocols.rs
lib/runtime/src/protocols.rs
+1
-0
lib/runtime/src/protocols/annotated.rs
lib/runtime/src/protocols/annotated.rs
+48
-1
lib/runtime/src/protocols/maybe_error.rs
lib/runtime/src/protocols/maybe_error.rs
+61
-0
No files found.
lib/bindings/python/rust/lib.rs
View file @
b4ddca99
...
@@ -214,7 +214,7 @@ struct Endpoint {
...
@@ -214,7 +214,7 @@ struct Endpoint {
#[pyclass]
#[pyclass]
#[derive(Clone)]
#[derive(Clone)]
struct
Client
{
struct
Client
{
router
:
rs
::
pipeline
::
PushRouter
<
serde_json
::
Value
,
serde_json
::
Value
>
,
router
:
rs
::
pipeline
::
PushRouter
<
serde_json
::
Value
,
RsAnnotated
<
serde_json
::
Value
>
>
,
}
}
#[pyclass(eq,
eq_int)]
#[pyclass(eq,
eq_int)]
...
@@ -485,13 +485,12 @@ impl Endpoint {
...
@@ -485,13 +485,12 @@ impl Endpoint {
let
inner
=
self
.inner
.clone
();
let
inner
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
client
=
inner
.client
()
.await
.map_err
(
to_pyerr
)
?
;
let
client
=
inner
.client
()
.await
.map_err
(
to_pyerr
)
?
;
let
push_router
=
let
push_router
=
rs
::
pipeline
::
PushRouter
::
<
rs
::
pipeline
::
PushRouter
::
<
serde_json
::
Value
,
serde_json
::
Value
>
::
from_client
(
serde_json
::
Value
,
client
,
RsAnnotated
<
serde_json
::
Value
>
,
Default
::
default
(),
>
::
from_client
(
client
,
Default
::
default
())
)
.await
.await
.map_err
(
to_pyerr
)
?
;
.map_err
(
to_pyerr
)
?
;
Ok
(
Client
{
Ok
(
Client
{
router
:
push_router
,
router
:
push_router
,
})
})
...
@@ -757,23 +756,13 @@ impl Client {
...
@@ -757,23 +756,13 @@ impl Client {
}
}
async
fn
process_stream
(
async
fn
process_stream
(
stream
:
EngineStream
<
serde_json
::
Value
>
,
stream
:
EngineStream
<
RsAnnotated
<
serde_json
::
Value
>
>
,
tx
:
tokio
::
sync
::
mpsc
::
Sender
<
RsAnnotated
<
PyObject
>>
,
tx
:
tokio
::
sync
::
mpsc
::
Sender
<
RsAnnotated
<
PyObject
>>
,
)
{
)
{
let
mut
stream
=
stream
;
let
mut
stream
=
stream
;
while
let
Some
(
response
)
=
stream
.next
()
.await
{
while
let
Some
(
response
)
=
stream
.next
()
.await
{
// Convert the response to a PyObject using Python's GIL
// Convert the response to a PyObject using Python's GIL
// TODO: Remove the clone, but still log the full JSON string on error. But how?
let
annotated
:
RsAnnotated
<
serde_json
::
Value
>
=
response
;
let
annotated
:
RsAnnotated
<
serde_json
::
Value
>
=
match
serde_json
::
from_value
(
response
.clone
(),
)
{
Ok
(
a
)
=>
a
,
Err
(
err
)
=>
{
tracing
::
error!
(
%
err
,
%
response
,
"process_stream: Failed de-serializing JSON into RsAnnotated"
);
break
;
}
};
let
annotated
:
RsAnnotated
<
PyObject
>
=
annotated
.map_data
(|
data
|
{
let
annotated
:
RsAnnotated
<
PyObject
>
=
annotated
.map_data
(|
data
|
{
let
result
=
Python
::
with_gil
(|
py
|
match
pythonize
::
pythonize
(
py
,
&
data
)
{
let
result
=
Python
::
with_gil
(|
py
|
match
pythonize
::
pythonize
(
py
,
&
data
)
{
Ok
(
pyobj
)
=>
Ok
(
pyobj
.into
()),
Ok
(
pyobj
)
=>
Ok
(
pyobj
.into
()),
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
b4ddca99
...
@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
...
@@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
FinishReason
;
pub
use
super
::
FinishReason
;
use
crate
::
protocols
::
TokenIdType
;
use
crate
::
protocols
::
TokenIdType
;
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
LogProbs
=
Vec
<
f64
>
;
pub
type
LogProbs
=
Vec
<
f64
>
;
...
@@ -134,6 +135,20 @@ impl LLMEngineOutput {
...
@@ -134,6 +135,20 @@ impl LLMEngineOutput {
}
}
}
}
impl
MaybeError
for
LLMEngineOutput
{
fn
from_err
(
err
:
Box
<
dyn
std
::
error
::
Error
>
)
->
Self
{
LLMEngineOutput
::
error
(
format!
(
"{:?}"
,
err
))
}
fn
err
(
&
self
)
->
Option
<
Box
<
dyn
std
::
error
::
Error
>>
{
if
let
Some
(
FinishReason
::
Error
(
err_msg
))
=
&
self
.finish_reason
{
Some
(
anyhow
::
Error
::
msg
(
err_msg
.clone
())
.into
())
}
else
{
None
}
}
}
/// Raw output from embedding engines containing embedding vectors
/// Raw output from embedding engines containing embedding vectors
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
pub
struct
EmbeddingsEngineOutput
{
pub
struct
EmbeddingsEngineOutput
{
...
@@ -144,3 +159,26 @@ pub struct EmbeddingsEngineOutput {
...
@@ -144,3 +159,26 @@ pub struct EmbeddingsEngineOutput {
pub
prompt_tokens
:
u32
,
pub
prompt_tokens
:
u32
,
pub
total_tokens
:
u32
,
pub
total_tokens
:
u32
,
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_maybe_error
()
{
let
output
=
LLMEngineOutput
::
stop
();
assert
!
(
output
.err
()
.is_none
());
assert
!
(
output
.is_ok
());
assert
!
(
!
output
.is_err
());
let
output
=
LLMEngineOutput
::
error
(
"Test error"
.to_string
());
assert_eq!
(
format!
(
"{}"
,
output
.err
()
.unwrap
()),
"Test error"
);
assert
!
(
!
output
.is_ok
());
assert
!
(
output
.is_err
());
let
output
=
LLMEngineOutput
::
from_err
(
anyhow
::
Error
::
msg
(
"Test error 2"
)
.into
());
assert_eq!
(
format!
(
"{}"
,
output
.err
()
.unwrap
()),
"Test error 2"
);
assert
!
(
!
output
.is_ok
());
assert
!
(
output
.is_err
());
}
}
lib/runtime/src/pipeline/network.rs
View file @
b4ddca99
...
@@ -323,3 +323,54 @@ impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
...
@@ -323,3 +323,54 @@ impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> {
pub
trait
PushWorkHandler
:
Send
+
Sync
{
pub
trait
PushWorkHandler
:
Send
+
Sync
{
async
fn
handle_payload
(
&
self
,
payload
:
Bytes
)
->
Result
<
(),
PipelineError
>
;
async
fn
handle_payload
(
&
self
,
payload
:
Bytes
)
->
Result
<
(),
PipelineError
>
;
}
}
/*
/// `NetworkStreamWrapper` is a simple wrapper used to detect proper stream termination
/// in network communication between ingress and egress components.
///
/// **Purpose**: This wrapper solves the problem of detecting whether a stream ended
/// gracefully or was cut off prematurely (e.g., due to network issues).
///
/// **Design Rationale**:
/// - Cannot use `Annotated` directly because the generic type `U` varies:
/// - Sometimes `U = Annotated<...>`
/// - Sometimes `U = LLMEngineOutput<...>`
/// - Using `Annotated` would require double-wrapping like `Annotated<Annotated<...>>`
/// - A simple wrapper is cleaner and more straightforward
///
/// **Stream Flow**:
/// ```
/// At AsyncEngine:
/// response 1 -> response 2 -> response 3 -> <end>
///
/// Between ingress/egress:
/// response 1 <end=false> -> response 2 <end=false> -> response 3 <end=false> -> (null) <end=true>
///
/// At client:
/// response 1 -> response 2 -> response 3 -> <end>
/// ```
///
/// **Error Handling**:
/// If the stream is cut off before proper termination, the egress is responsible for
/// injecting an error response to communicate the incomplete stream to the client:
/// ```
/// At AsyncEngine:
/// response 1 -> ... <without end flag>
///
/// At egress:
/// response 1 <end=false> -> <stream ended without end flag -> convert to error>
///
/// At client:
/// response 1 -> error response
/// ```
///
/// The detection must be done at egress level because premature stream termination
/// can be due to network issues that only the egress component can detect.
*/
/// TODO: Detect end-of-stream using Server-Sent Events (SSE). This will be removed.
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
NetworkStreamWrapper
<
U
>
{
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
data
:
Option
<
U
>
,
pub
complete_final
:
bool
,
}
lib/runtime/src/pipeline/network/egress/addressed_router.rs
View file @
b4ddca99
...
@@ -17,7 +17,8 @@ use async_nats::client::Client;
...
@@ -17,7 +17,8 @@ use async_nats::client::Client;
use
tracing
as
log
;
use
tracing
as
log
;
use
super
::
*
;
use
super
::
*
;
use
crate
::
Result
;
use
crate
::{
protocols
::
maybe_error
::
MaybeError
,
Result
};
use
tokio_stream
::{
wrappers
::
ReceiverStream
,
StreamExt
,
StreamNotifyClose
};
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[serde(rename_all
=
"snake_case"
)]
#[serde(rename_all
=
"snake_case"
)]
...
@@ -80,7 +81,7 @@ impl AddressedPushRouter {
...
@@ -80,7 +81,7 @@ impl AddressedPushRouter {
impl
<
T
,
U
>
AsyncEngine
<
SingleIn
<
AddressedRequest
<
T
>>
,
ManyOut
<
U
>
,
Error
>
for
AddressedPushRouter
impl
<
T
,
U
>
AsyncEngine
<
SingleIn
<
AddressedRequest
<
T
>>
,
ManyOut
<
U
>
,
Error
>
for
AddressedPushRouter
where
where
T
:
Data
+
Serialize
,
T
:
Data
+
Serialize
,
U
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
,
U
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
+
MaybeError
,
{
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
AddressedRequest
<
T
>>
)
->
Result
<
ManyOut
<
U
>
,
Error
>
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
AddressedRequest
<
T
>>
)
->
Result
<
ManyOut
<
U
>
,
Error
>
{
let
request_id
=
request
.context
()
.id
()
.to_string
();
let
request_id
=
request
.context
()
.id
()
.to_string
();
...
@@ -160,16 +161,49 @@ where
...
@@ -160,16 +161,49 @@ where
.map_err
(|
_
|
PipelineError
::
DetatchedStreamReceiver
)
?
.map_err
(|
_
|
PipelineError
::
DetatchedStreamReceiver
)
?
.map_err
(
PipelineError
::
ConnectionFailed
)
?
;
.map_err
(
PipelineError
::
ConnectionFailed
)
?
;
let
stream
=
tokio_stream
::
wrappers
::
ReceiverStream
::
new
(
response_stream
.rx
);
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
let
mut
is_complete_final
=
false
;
let
stream
=
stream
.filter_map
(|
msg
|
async
move
{
let
stream
=
tokio_stream
::
StreamNotifyClose
::
new
(
match
serde_json
::
from_slice
::
<
U
>
(
&
msg
)
{
tokio_stream
::
wrappers
::
ReceiverStream
::
new
(
response_stream
.rx
),
Ok
(
r
)
=>
Some
(
r
),
)
Err
(
err
)
=>
{
.filter_map
(
move
|
res
|
{
let
json_str
=
String
::
from_utf8_lossy
(
&
msg
);
if
let
Some
(
res_bytes
)
=
res
{
log
::
warn!
(
%
err
,
%
json_str
,
"Failed deserializing JSON to response"
);
if
is_complete_final
{
None
return
Some
(
U
::
from_err
(
Error
::
msg
(
"Response received after generation ended - this should never happen"
,
)
.into
(),
));
}
match
serde_json
::
from_slice
::
<
NetworkStreamWrapper
<
U
>>
(
&
res_bytes
)
{
Ok
(
item
)
=>
{
is_complete_final
=
item
.complete_final
;
if
let
Some
(
data
)
=
item
.data
{
Some
(
data
)
}
else
if
is_complete_final
{
None
}
else
{
Some
(
U
::
from_err
(
Error
::
msg
(
"Empty response received - this should never happen"
)
.into
(),
))
}
}
Err
(
err
)
=>
{
// legacy log print
let
json_str
=
String
::
from_utf8_lossy
(
&
res_bytes
);
log
::
warn!
(
%
err
,
%
json_str
,
"Failed deserializing JSON to response"
);
Some
(
U
::
from_err
(
Error
::
new
(
err
)
.into
()))
}
}
}
}
else
if
is_complete_final
{
None
}
else
{
Some
(
U
::
from_err
(
Error
::
msg
(
"Stream ended before generation completed"
)
.into
(),
))
}
}
});
});
...
...
lib/runtime/src/pipeline/network/egress/push_router.rs
View file @
b4ddca99
...
@@ -13,6 +13,16 @@
...
@@ -13,6 +13,16 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
super
::{
AsyncEngineContextProvider
,
ResponseStream
};
use
crate
::{
component
::{
Client
,
Endpoint
,
InstanceSource
},
engine
::{
AsyncEngine
,
Data
},
pipeline
::{
error
::
PipelineErrorExt
,
AddressedPushRouter
,
AddressedRequest
,
Error
,
ManyOut
,
SingleIn
,
},
protocols
::
maybe_error
::
MaybeError
,
traits
::
DistributedRuntimeProvider
,
};
use
async_nats
::
client
::{
use
async_nats
::
client
::{
RequestError
as
NatsRequestError
,
RequestErrorKind
::
NoResponders
as
NatsNoResponders
,
RequestError
as
NatsRequestError
,
RequestErrorKind
::
NoResponders
as
NatsNoResponders
,
};
};
...
@@ -27,15 +37,7 @@ use std::{
...
@@ -27,15 +37,7 @@ use std::{
Arc
,
Arc
,
},
},
};
};
use
tokio_stream
::
StreamExt
;
use
crate
::{
component
::{
Client
,
Endpoint
,
InstanceSource
},
engine
::{
AsyncEngine
,
Data
},
pipeline
::{
error
::
PipelineErrorExt
,
AddressedPushRouter
,
AddressedRequest
,
Error
,
ManyOut
,
SingleIn
,
},
traits
::
DistributedRuntimeProvider
,
};
#[derive(Clone)]
#[derive(Clone)]
pub
struct
PushRouter
<
T
,
U
>
pub
struct
PushRouter
<
T
,
U
>
...
@@ -94,7 +96,7 @@ async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPu
...
@@ -94,7 +96,7 @@ async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPu
impl
<
T
,
U
>
PushRouter
<
T
,
U
>
impl
<
T
,
U
>
PushRouter
<
T
,
U
>
where
where
T
:
Data
+
Serialize
,
T
:
Data
+
Serialize
,
U
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
,
U
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
+
MaybeError
,
{
{
pub
async
fn
from_client
(
client
:
Client
,
router_mode
:
RouterMode
)
->
anyhow
::
Result
<
Self
>
{
pub
async
fn
from_client
(
client
:
Client
,
router_mode
:
RouterMode
)
->
anyhow
::
Result
<
Self
>
{
let
addressed
=
addressed_router
(
&
client
.endpoint
)
.await
?
;
let
addressed
=
addressed_router
(
&
client
.endpoint
)
.await
?
;
...
@@ -109,51 +111,44 @@ where
...
@@ -109,51 +111,44 @@ where
/// Issue a request to the next available instance in a round-robin fashion
/// Issue a request to the next available instance in a round-robin fashion
pub
async
fn
round_robin
(
&
self
,
request
:
SingleIn
<
T
>
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
pub
async
fn
round_robin
(
&
self
,
request
:
SingleIn
<
T
>
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
let
slf
=
self
;
let
counter
=
self
.round_robin_counter
.fetch_add
(
1
,
Ordering
::
Relaxed
);
let
routing_algorithm
=
move
||
async
move
{
let
counter
=
slf
.round_robin_counter
.fetch_add
(
1
,
Ordering
::
Relaxed
);
let
instance_id
=
{
let
instances
=
slf
.client
.instances_avail
()
.await
;
let
count
=
instances
.len
();
if
count
==
0
{
return
Err
(
anyhow
::
anyhow!
(
"no instances found for endpoint {:?}"
,
slf
.client.endpoint
.etcd_root
()
));
}
let
offset
=
counter
%
count
as
u64
;
instances
[
offset
as
usize
]
.id
()
};
tracing
::
trace!
(
"round robin router selected {instance_id}"
);
Ok
(
instance_id
)
let
instance_id
=
{
let
instances
=
self
.client
.instances_avail
()
.await
;
let
count
=
instances
.len
();
if
count
==
0
{
return
Err
(
anyhow
::
anyhow!
(
"no instances found for endpoint {:?}"
,
self
.client.endpoint
.etcd_root
()
));
}
let
offset
=
counter
%
count
as
u64
;
instances
[
offset
as
usize
]
.id
()
};
};
self
.generate_with_fault_tolerance
(
routing_algorithm
,
request
)
tracing
::
trace!
(
"round robin router selected {instance_id}"
);
self
.generate_with_fault_detection
(
instance_id
,
request
)
.await
.await
}
}
/// Issue a request to a random endpoint
/// Issue a request to a random endpoint
pub
async
fn
random
(
&
self
,
request
:
SingleIn
<
T
>
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
pub
async
fn
random
(
&
self
,
request
:
SingleIn
<
T
>
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
let
slf
=
self
;
let
instance_id
=
{
let
routing_algorithm
=
move
||
async
move
{
let
instances
=
self
.client
.instances_avail
()
.await
;
let
instance_id
=
{
let
count
=
instances
.len
();
let
instances
=
slf
.client
.instances_avail
()
.await
;
if
count
==
0
{
let
count
=
instances
.len
();
return
Err
(
anyhow
::
anyhow!
(
if
count
==
0
{
"no instances found for endpoint {:?}"
,
return
Err
(
anyhow
::
anyhow!
(
self
.client.endpoint
.etcd_root
()
"no instances found for endpoint {:?}"
,
));
slf
.client.endpoint
.etcd_root
()
}
));
let
counter
=
rand
::
rng
()
.random
::
<
u64
>
();
}
let
offset
=
counter
%
count
as
u64
;
let
counter
=
rand
::
rng
()
.random
::
<
u64
>
();
instances
[
offset
as
usize
]
.id
()
let
offset
=
counter
%
count
as
u64
;
instances
[
offset
as
usize
]
.id
()
};
tracing
::
trace!
(
"random router selected {instance_id}"
);
Ok
(
instance_id
)
};
};
self
.generate_with_fault_tolerance
(
routing_algorithm
,
request
)
tracing
::
trace!
(
"random router selected {instance_id}"
);
self
.generate_with_fault_detection
(
instance_id
,
request
)
.await
.await
}
}
...
@@ -163,22 +158,19 @@ where
...
@@ -163,22 +158,19 @@ where
request
:
SingleIn
<
T
>
,
request
:
SingleIn
<
T
>
,
instance_id
:
i64
,
instance_id
:
i64
,
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
let
slf
=
self
;
let
found
=
{
let
routing_algorithm
=
move
||
async
move
{
let
instances
=
self
.client
.instances_avail
()
.await
;
let
found
=
{
instances
.iter
()
.any
(|
ep
|
ep
.id
()
==
instance_id
)
let
instances
=
slf
.client
.instances_avail
()
.await
;
instances
.iter
()
.any
(|
ep
|
ep
.id
()
==
instance_id
)
};
if
!
found
{
return
Err
(
anyhow
::
anyhow!
(
"instance_id={instance_id} not found for endpoint {:?}"
,
slf
.client.endpoint
.etcd_root
()
));
}
Ok
(
instance_id
)
};
};
self
.generate_with_fault_tolerance
(
routing_algorithm
,
request
)
if
!
found
{
return
Err
(
anyhow
::
anyhow!
(
"instance_id={instance_id} not found for endpoint {:?}"
,
self
.client.endpoint
.etcd_root
()
));
}
self
.generate_with_fault_detection
(
instance_id
,
request
)
.await
.await
}
}
...
@@ -190,29 +182,45 @@ where
...
@@ -190,29 +182,45 @@ where
self
.addressed
.generate
(
request
)
.await
self
.addressed
.generate
(
request
)
.await
}
}
async
fn
generate_with_fault_
tolerance
<
F
,
R
>
(
async
fn
generate_with_fault_
detection
(
&
self
,
&
self
,
routing_algorithm
:
F
,
instance_id
:
i64
,
request
:
SingleIn
<
T
>
,
request
:
SingleIn
<
T
>
,
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
)
->
anyhow
::
Result
<
ManyOut
<
U
>>
{
where
F
:
FnOnce
()
->
R
,
R
:
Future
<
Output
=
anyhow
::
Result
<
i64
>>
,
{
let
instance_id
=
routing_algorithm
()
.await
?
;
let
subject
=
self
.client.endpoint
.subject_to
(
instance_id
);
let
subject
=
self
.client.endpoint
.subject_to
(
instance_id
);
let
request
=
request
.map
(|
req
|
AddressedRequest
::
new
(
req
,
subject
));
let
request
=
request
.map
(|
req
|
AddressedRequest
::
new
(
req
,
subject
));
let
stream
=
self
.addressed
.generate
(
request
)
.await
;
let
stream
:
anyhow
::
Result
<
ManyOut
<
U
>>
=
self
.addressed
.generate
(
request
)
.await
;
if
let
Some
(
err
)
=
stream
.as_ref
()
.err
()
{
match
stream
{
if
let
Some
(
req_err
)
=
err
.downcast_ref
::
<
NatsRequestError
>
()
{
Ok
(
stream
)
=>
{
if
matches!
(
req_err
.kind
(),
NatsNoResponders
)
{
let
engine_ctx
=
stream
.context
();
self
.client
.report_instance_down
(
instance_id
)
.await
;
let
client
=
self
.client
.clone
();
let
stream
=
stream
.then
(
move
|
res
|
{
let
mut
report_instance_down
:
Option
<
(
Client
,
i64
)
>
=
None
;
if
let
Some
(
err
)
=
res
.err
()
{
const
STREAM_ERR_MSG
:
&
str
=
"Stream ended before generation completed"
;
if
format!
(
"{:?}"
,
err
)
==
STREAM_ERR_MSG
{
report_instance_down
=
Some
((
client
.clone
(),
instance_id
));
}
}
async
move
{
if
let
Some
((
client
,
instance_id
))
=
report_instance_down
{
client
.report_instance_down
(
instance_id
)
.await
;
}
res
}
});
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
stream
),
engine_ctx
))
}
Err
(
err
)
=>
{
if
let
Some
(
req_err
)
=
err
.downcast_ref
::
<
NatsRequestError
>
()
{
if
matches!
(
req_err
.kind
(),
NatsNoResponders
)
{
self
.client
.report_instance_down
(
instance_id
)
.await
;
}
}
}
Err
(
err
)
}
}
}
}
stream
}
}
}
}
...
@@ -220,7 +228,7 @@ where
...
@@ -220,7 +228,7 @@ where
impl
<
T
,
U
>
AsyncEngine
<
SingleIn
<
T
>
,
ManyOut
<
U
>
,
Error
>
for
PushRouter
<
T
,
U
>
impl
<
T
,
U
>
AsyncEngine
<
SingleIn
<
T
>
,
ManyOut
<
U
>
,
Error
>
for
PushRouter
<
T
,
U
>
where
where
T
:
Data
+
Serialize
,
T
:
Data
+
Serialize
,
U
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
,
U
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
+
MaybeError
,
{
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
T
>
)
->
Result
<
ManyOut
<
U
>
,
Error
>
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
T
>
)
->
Result
<
ManyOut
<
U
>
,
Error
>
{
match
self
.client.instance_source
.as_ref
()
{
match
self
.client.instance_source
.as_ref
()
{
...
...
lib/runtime/src/pipeline/network/ingress/push_handler.rs
View file @
b4ddca99
...
@@ -97,16 +97,37 @@ where
...
@@ -97,16 +97,37 @@ where
let
context
=
stream
.context
();
let
context
=
stream
.context
();
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
let
mut
send_complete_final
=
true
;
while
let
Some
(
resp
)
=
stream
.next
()
.await
{
while
let
Some
(
resp
)
=
stream
.next
()
.await
{
tracing
::
trace!
(
"Sending response: {:?}"
,
resp
);
tracing
::
trace!
(
"Sending response: {:?}"
,
resp
);
let
resp_bytes
=
serde_json
::
to_vec
(
&
resp
)
let
resp_wrapper
=
NetworkStreamWrapper
{
data
:
Some
(
resp
),
complete_final
:
false
,
};
let
resp_bytes
=
serde_json
::
to_vec
(
&
resp_wrapper
)
.expect
(
"fatal error: invalid response object - this should never happen"
);
.expect
(
"fatal error: invalid response object - this should never happen"
);
if
(
publisher
.send
(
resp_bytes
.into
())
.await
)
.is_err
()
{
if
(
publisher
.send
(
resp_bytes
.into
())
.await
)
.is_err
()
{
tracing
::
error!
(
"Failed to publish response for stream {}"
,
context
.id
());
tracing
::
error!
(
"Failed to publish response for stream {}"
,
context
.id
());
context
.stop_generating
();
context
.stop_generating
();
send_complete_final
=
false
;
break
;
break
;
}
}
}
}
if
send_complete_final
{
let
resp_wrapper
=
NetworkStreamWrapper
::
<
U
>
{
data
:
None
,
complete_final
:
true
,
};
let
resp_bytes
=
serde_json
::
to_vec
(
&
resp_wrapper
)
.expect
(
"fatal error: invalid response object - this should never happen"
);
if
(
publisher
.send
(
resp_bytes
.into
())
.await
)
.is_err
()
{
tracing
::
error!
(
"Failed to publish complete final for stream {}"
,
context
.id
()
);
}
}
Ok
(())
Ok
(())
}
}
...
...
lib/runtime/src/protocols.rs
View file @
b4ddca99
...
@@ -19,6 +19,7 @@ use std::str::FromStr;
...
@@ -19,6 +19,7 @@ use std::str::FromStr;
use
crate
::
pipeline
::
PipelineError
;
use
crate
::
pipeline
::
PipelineError
;
pub
mod
annotated
;
pub
mod
annotated
;
pub
mod
maybe_error
;
pub
type
LeaseId
=
i64
;
pub
type
LeaseId
=
i64
;
...
...
lib/runtime/src/protocols/annotated.rs
View file @
b4ddca99
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
use
super
::
*
;
use
super
::
*
;
use
crate
::{
error
,
Result
};
use
crate
::{
error
,
Result
};
use
maybe_error
::
MaybeError
;
pub
trait
AnnotationsProvider
{
pub
trait
AnnotationsProvider
{
fn
annotations
(
&
self
)
->
Option
<
Vec
<
String
>>
;
fn
annotations
(
&
self
)
->
Option
<
Vec
<
String
>>
;
...
@@ -28,7 +29,7 @@ pub trait AnnotationsProvider {
...
@@ -28,7 +29,7 @@ pub trait AnnotationsProvider {
/// Our services have the option of returning an "annotated" stream, which allows use
/// Our services have the option of returning an "annotated" stream, which allows use
/// to include additional information with each delta. This is useful for debugging,
/// to include additional information with each delta. This is useful for debugging,
/// performance benchmarking, and improved observability.
/// performance benchmarking, and improved observability.
#[derive(Serialize,
Deserialize,
Debug)]
#[derive(Serialize,
Deserialize,
Clone,
Debug)]
pub
struct
Annotated
<
R
>
{
pub
struct
Annotated
<
R
>
{
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
data
:
Option
<
R
>
,
pub
data
:
Option
<
R
>
,
...
@@ -146,6 +147,28 @@ impl<R> Annotated<R> {
...
@@ -146,6 +147,28 @@ impl<R> Annotated<R> {
}
}
}
}
impl
<
R
>
MaybeError
for
Annotated
<
R
>
where
R
:
for
<
'de
>
Deserialize
<
'de
>
+
Serialize
,
{
fn
from_err
(
err
:
Box
<
dyn
std
::
error
::
Error
>
)
->
Self
{
Annotated
::
from_error
(
format!
(
"{:?}"
,
err
))
}
fn
err
(
&
self
)
->
Option
<
Box
<
dyn
std
::
error
::
Error
>>
{
if
self
.is_error
()
{
if
let
Some
(
comment
)
=
&
self
.comment
{
if
!
comment
.is_empty
()
{
return
Some
(
anyhow
::
Error
::
msg
(
comment
.join
(
"; "
))
.into
());
}
}
Some
(
anyhow
::
Error
::
msg
(
"unknown error"
)
.into
())
}
else
{
None
}
}
}
// impl<R> Annotated<R>
// impl<R> Annotated<R>
// where
// where
// R: for<'de> Deserialize<'de> + Serialize,
// R: for<'de> Deserialize<'de> + Serialize,
...
@@ -166,3 +189,27 @@ impl<R> Annotated<R> {
...
@@ -166,3 +189,27 @@ impl<R> Annotated<R> {
// Box::pin(stream)
// Box::pin(stream)
// }
// }
// }
// }
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_maybe_error
()
{
let
annotated
=
Annotated
::
from_data
(
"Test data"
.to_string
());
assert
!
(
annotated
.err
()
.is_none
());
assert
!
(
annotated
.is_ok
());
assert
!
(
!
annotated
.is_err
());
let
annotated
=
Annotated
::
<
String
>
::
from_error
(
"Test error 2"
.to_string
());
assert_eq!
(
format!
(
"{}"
,
annotated
.err
()
.unwrap
()),
"Test error 2"
);
assert
!
(
!
annotated
.is_ok
());
assert
!
(
annotated
.is_err
());
let
annotated
=
Annotated
::
<
String
>
::
from_err
(
anyhow
::
Error
::
msg
(
"Test error 3"
.to_string
())
.into
());
assert_eq!
(
format!
(
"{}"
,
annotated
.err
()
.unwrap
()),
"Test error 3"
);
assert
!
(
!
annotated
.is_ok
());
assert
!
(
annotated
.is_err
());
}
}
lib/runtime/src/protocols/maybe_error.rs
0 → 100644
View file @
b4ddca99
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
error
::
Error
;
pub
trait
MaybeError
{
/// Construct an instance from an error.
fn
from_err
(
err
:
Box
<
dyn
Error
>
)
->
Self
;
/// Construct into an error instance.
fn
err
(
&
self
)
->
Option
<
Box
<
dyn
Error
>>
;
/// Check if the current instance represents a success.
fn
is_ok
(
&
self
)
->
bool
{
!
self
.is_err
()
}
/// Check if the current instance represents an error.
fn
is_err
(
&
self
)
->
bool
{
self
.err
()
.is_some
()
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
struct
TestError
{
message
:
String
,
}
impl
MaybeError
for
TestError
{
fn
from_err
(
err
:
Box
<
dyn
Error
>
)
->
Self
{
TestError
{
message
:
err
.to_string
(),
}
}
fn
err
(
&
self
)
->
Option
<
Box
<
dyn
Error
>>
{
Some
(
anyhow
::
Error
::
msg
(
self
.message
.clone
())
.into
())
}
}
#[test]
fn
test_maybe_error_default_implementations
()
{
let
err
=
TestError
::
from_err
(
anyhow
::
Error
::
msg
(
"Test error"
.to_string
())
.into
());
assert_eq!
(
format!
(
"{}"
,
err
.err
()
.unwrap
()),
"Test error"
);
assert
!
(
!
err
.is_ok
());
assert
!
(
err
.is_err
());
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment