Commit 5ed8c1c0 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: rust - initial commit

the journey begins
parent 4017bd18
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
pub mod push_endpoint;
pub mod push_handler;
use super::*;
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use anyhow::Result;
use async_nats::service::endpoint::Endpoint;
use derive_builder::Builder;
use tokio_util::sync::CancellationToken;
use tracing as log;
/// A push-based NATS service endpoint that forwards incoming request
/// payloads to a [`PushWorkHandler`] until cancelled.
#[derive(Builder)]
pub struct PushEndpoint {
    /// Handler invoked (on a spawned task) for each received payload.
    pub service_handler: Arc<dyn PushWorkHandler>,
    /// Token used to signal the endpoint loop to shut down.
    pub cancellation_token: CancellationToken,
}
/// Version of this crate, captured from `Cargo.toml` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
impl PushEndpoint {
    /// Returns a builder for constructing a [`PushEndpoint`].
    pub fn builder() -> PushEndpointBuilder {
        PushEndpointBuilder::default()
    }

    /// Drives the NATS service endpoint until its stream closes or the
    /// cancellation token fires.
    ///
    /// Each incoming request is acknowledged immediately with an empty
    /// response; the payload is then handed to the configured
    /// [`PushWorkHandler`] on a spawned task so the accept loop stays
    /// responsive.
    pub async fn start(self, mut endpoint: Endpoint) -> Result<()> {
        loop {
            let req = tokio::select! {
                biased;

                // await on service request
                req = endpoint.next() => req,

                // process shutdown: stop the NATS endpoint and exit the loop
                _ = self.cancellation_token.cancelled() => {
                    if let Err(e) = endpoint.stop().await {
                        log::warn!("Failed to stop NATS service: {:?}", e);
                    }
                    break;
                }
            };

            // `None` means the endpoint's request stream has closed.
            let Some(req) = req else { break };

            // Ack the request right away; the actual work happens out-of-band.
            let response = String::new();
            if let Err(e) = req.respond(Ok(response.into())).await {
                log::warn!("Failed to respond to request; this may indicate the request has shutdown: {:?}", e);
            }

            let ingress = self.service_handler.clone();
            let worker_id = "".to_string();
            tokio::spawn(async move {
                log::trace!(worker_id, "handling new request");
                let result = ingress.handle_payload(req.message.payload).await;
                log::trace!(worker_id, "request handled: {:?}", result);
            });
        }
        Ok(())
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use anyhow::Result;
use serde::{Deserialize, Serialize};
#[async_trait]
impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + std::fmt::Debug,
{
    /// Decodes a two-part (header + data) payload, drives the segment's
    /// `generate` call, and streams each response back over a TCP response
    /// stream described by the control message.
    ///
    /// # Errors
    /// Returns a [`PipelineError`] when the payload cannot be decoded, the
    /// response stream cannot be created, or `generate` fails.
    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
        // decode the control message and the request
        let msg = TwoPartCodec::default()
            .decode_message(payload)?
            .into_message_type();

        // we must have a header (control message) and a body (request)
        let (control_msg, request) = match msg {
            TwoPartMessageType::HeaderAndData(header, data) => {
                tracing::trace!(
                    "received two part message with ctrl: {} bytes, data: {} bytes",
                    header.len(),
                    data.len()
                );
                // Propagate a malformed header as a pipeline error instead of
                // panicking on network-supplied input (was `.unwrap()`).
                let control_msg: RequestControlMessage = serde_json::from_slice(&header)
                    .map_err(|e| {
                        PipelineError::Generic(format!(
                            "Failed to deserialize control message: {:?}",
                            e
                        ))
                    })?;
                let request: T = serde_json::from_slice(&data)?;
                (control_msg, request)
            }
            _ => {
                return Err(PipelineError::Generic(String::from("Unexpected message from work queue; unable extract a TwoPartMessage with a header and data")));
            }
        };

        // extend request with context so downstream tracing lines up with the caller's id
        tracing::trace!("received control message: {:?}", control_msg);
        tracing::trace!("received request: {:?}", request);
        let request: context::Context<T> = Context::with_id(request, control_msg.id);

        // todo - eventually have a handler class which will returned an abstracted object, but for now,
        // we only support tcp here, so we can just unwrap the connection info
        tracing::trace!("creating tcp response stream");
        let mut publisher = tcp::client::TcpClient::create_response_steam(
            request.context(),
            control_msg.connection_info,
        )
        .await
        .map_err(|e| {
            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
        })?;

        tracing::trace!("calling generate");
        let stream = self
            .segment
            .get()
            .expect("segment not set")
            .generate(request)
            .await
            .map_err(PipelineError::GenerateError);

        // The prologue is sent to the client to indicate that the stream is ready to
        // receive data; if the generate call failed, the error is sent to the client.
        let mut stream = match stream {
            Ok(stream) => {
                tracing::trace!("Successfully generated response stream; sending prologue");
                let _result = publisher.send_prologue(None).await;
                stream
            }
            Err(e) => {
                tracing::error!("Failed to generate response stream: {:?}", e);
                let _result = publisher.send_prologue(Some(e.to_string())).await;
                Err(e)?
            }
        };

        // Forward each generated response; stop generating if the publisher fails.
        let context = stream.context();
        while let Some(resp) = stream.next().await {
            tracing::trace!("Sending response: {:?}", resp);
            let resp_bytes = serde_json::to_vec(&resp)
                .expect("fatal error: invalid response object - this should never happen");
            if (publisher.send(resp_bytes.into()).await).is_err() {
                tracing::error!("Failed to publish response for stream {}", context.id());
                context.stop_generating();
                break;
            }
        }

        Ok(())
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! TCP Transport Module
//!
//! The TCP Transport module consists of two main components: Client and Server. The Client is
//! the downstream node that is responsible for connecting back to the upstream node (Server).
//!
//! Both Client and Server are given a Stream object that they can specialize for their specific
//! needs, i.e. if they are SingleIn/ManyIn or SingleOut/ManyOut.
//!
//! The Request object will carry the Transport Type and Connection details, i.e. how the receiver
//! of a Request is able to communicate back to the source of the Request.
//!
//! There are two types of TcpStream:
//! - CallHome stream - the address for the listening socket is forward via some mechanism which then
//! connects back to the source of the CallHome stream. To match the socket with an awaiting data
//! stream, the CallHomeHandshake is used.
pub mod client;
pub mod server;
use serde::{Deserialize, Serialize};
#[allow(unused_imports)]
use super::{
codec::TwoPartCodec, ConnectionInfo, PendingConnections, RegisteredStream, ResponseService,
StreamOptions, StreamReceiver, StreamSender, StreamType,
};
/// Transport tag that marks a [`ConnectionInfo`] as TCP-based.
const TCP_TRANSPORT: &str = "tcp_server";

/// Details a client needs to call back to a [`TcpStreamServer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TcpStreamConnectionInfo {
    /// `host:port` of the listening server.
    pub address: String,
    /// Subject that pairs the new socket with a registered stream.
    pub subject: String,
    /// Context (request) id the stream belongs to.
    pub context: String,
    /// Whether this describes a request or response stream.
    pub stream_type: StreamType,
}
impl From<TcpStreamConnectionInfo> for ConnectionInfo {
    /// Serializes the TCP details into a transport-tagged [`ConnectionInfo`].
    fn from(info: TcpStreamConnectionInfo) -> Self {
        // Serializing a struct of plain strings and a unit-style enum cannot
        // fail in practice, so a failure here is treated as fatal.
        // NOTE(review): if a non-fatal fallback is ever preferred, replace the
        // expect with `unwrap_or_else` and a default JSON payload.
        ConnectionInfo {
            transport: TCP_TRANSPORT.to_string(),
            info: serde_json::to_string(&info)
                .expect("Failed to serialize TcpStreamConnectionInfo"),
        }
    }
}
impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
    type Error = String;

    /// Validates the transport tag and deserializes the embedded TCP details.
    ///
    /// # Errors
    /// Returns a message when the transport is not `tcp_server` or when the
    /// embedded JSON cannot be parsed.
    fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
        if info.transport != TCP_TRANSPORT {
            return Err(format!(
                "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
                info.transport
            ));
        }
        // Fixed error-message typo: was "Failed parse ConnectionInfo".
        serde_json::from_str(&info.info)
            .map_err(|e| format!("Failed to parse ConnectionInfo: {:?}", e))
    }
}
/// First message sent over a CallHome stream which will map the newly created socket to a specific
/// response data stream which was registered with the same subject.
///
/// This is a transport specific message as part of forming/completing a CallHome TcpStream.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CallHomeHandshake {
    /// Subject of the registered stream this socket should be paired with.
    subject: String,
    /// Declares whether the socket carries a request or response stream.
    stream_type: StreamType,
}
/// Out-of-band control signals carried as header-only frames.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ControlMessage {
    /// Graceful stop; the client maps this to `context.stop()`.
    Stop,
    /// Immediate termination; the client maps this to `context.kill()`.
    Kill,
}
#[cfg(test)]
mod tests {
    use crate::engine::AsyncEngineContextProvider;

    use super::*;
    use crate::pipeline::Context;

    /// Serializable payload used to exercise the stream round-trip.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestMessage {
        foo: String,
    }

    /// End-to-end pairing test: register a response stream on the server,
    /// connect a client, send the prologue, then verify one message
    /// round-trips intact.
    #[tokio::test]
    async fn test_tcp_stream_client_server() {
        // Bind to an ephemeral port (0) so parallel test runs cannot collide
        // on a fixed port; the server reports the actual bound port.
        let options = server::ServerOptions::builder().port(0).build().unwrap();
        let server = server::TcpStreamServer::new(options).await.unwrap();
        println!("Server created");

        let context_rank0 = Context::new(());
        let options = StreamOptions::builder()
            .context(context_rank0.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending_connection = server.register(options).await;
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        // set up the other rank
        let context_rank1 = Context::with_id((), context_rank0.id().to_string());

        // connect to the server socket
        let mut send_stream =
            client::TcpClient::create_response_steam(context_rank1.context(), connection_info)
                .await
                .unwrap();
        println!("Client connected");

        // the client can now set up its end of the stream and if it errors, it can send a message
        // to the server to stop the stream
        //
        // this step must be done before the next step on the server can complete, i.e.
        // the server's stream is now blocked on receiving the prologue message
        //
        // let's improve this and use an enum like Ok/Err; currently, None means good-to-go, and
        // Some(String) means an error happened on this downstream node and we need to alert the
        // upstream node that an error occurred
        send_stream.send_prologue(None).await.unwrap();

        // [server] next - now pending connections should be connected
        let recv_stream = pending_connection
            .recv_stream
            .unwrap()
            .stream_provider
            .await
            .unwrap();
        println!("Server paired");

        let msg = TestMessage {
            foo: "bar".to_string(),
        };

        let payload = serde_json::to_vec(&msg).unwrap();
        send_stream.send(payload.into()).await.unwrap();
        println!("Client sent message");

        let data = recv_stream.unwrap().rx.recv().await.unwrap();
        println!("Server received message");

        let recv_msg = serde_json::from_slice::<TestMessage>(&data).unwrap();
        assert_eq!(msg.foo, recv_msg.foo);
        println!("message match");

        drop(send_stream);
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::{io::AsyncWriteExt, net::TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing as log;
use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
use crate::engine::AsyncEngineContext;
use crate::pipeline::network::{
codec::{TwoPartCodec, TwoPartMessage},
tcp::StreamType,
ConnectionInfo, ResponseStreamPrologue, StreamSender,
}; // Import SinkExt to use the `send` method
/// TCP client used by a downstream node to "call home" to an upstream
/// [`TcpStreamServer`](super::server::TcpStreamServer) and establish a stream.
#[allow(dead_code)]
pub struct TcpClient {
    /// Identifier for this worker; currently unused at runtime.
    worker_id: String,
}
impl Default for TcpClient {
    /// Builds a client whose worker id is a freshly generated UUID string.
    fn default() -> Self {
        let worker_id = uuid::Uuid::new_v4().to_string();
        Self { worker_id }
    }
}
impl TcpClient {
    /// Creates a client with the given worker id.
    pub fn new(worker_id: String) -> Self {
        TcpClient { worker_id }
    }

    /// Opens a TCP connection to `address` with `TCP_NODELAY` enabled so
    /// frames are flushed immediately.
    async fn connect(address: &str) -> Result<TcpStream, String> {
        let socket = TcpStream::connect(address)
            .await
            .map_err(|e| format!("failed to connect: {:?}", e))?;
        socket
            .set_nodelay(true)
            .map_err(|e| format!("failed to set nodelay: {:?}", e))?;
        Ok(socket)
    }

    /// Connects back ("calls home") to the upstream server described by
    /// `info`, performs the handshake, and returns a [`StreamSender`] for
    /// publishing response frames.
    ///
    /// NOTE(review): the method name has a typo (`steam` vs `stream`); it is
    /// public API used by callers in this crate, so renaming needs a
    /// coordinated change.
    ///
    /// # Errors
    /// Returns a message when `info` is not a TCP response-stream descriptor,
    /// when the context id does not match, or when connecting or sending the
    /// handshake fails.
    pub async fn create_response_steam(
        context: Arc<dyn AsyncEngineContext>,
        info: ConnectionInfo,
    ) -> Result<StreamSender, String> {
        let info = TcpStreamConnectionInfo::try_from(info)?;
        tracing::trace!("Creating response stream for {:?}", info);

        // Only response streams may be created through this entry point.
        if info.stream_type != StreamType::Response {
            return Err(format!(
                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
                info.stream_type
            ));
        }

        // The descriptor must belong to the same request context.
        if info.context != context.id() {
            return Err(format!(
                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
                context.id(),
                info.context
            ));
        }

        let stream = TcpClient::connect(&info.address).await?;
        let (read_half, write_half) = tokio::io::split(stream);
        let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());

        // this is a oneshot channel that will be used to signal when the stream is closed
        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
        // captured by the monitor task
        let (mut alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();

        // monitors the channel for a cancellation signal
        // this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    msg = framed_reader.next() => {
                        match msg {
                            Some(Ok(two_part_msg)) => {
                                match two_part_msg.optional_parts() {
                                    // Header-only frames carry control messages.
                                    (Some(bytes), None) => {
                                        // NOTE(review): unwrap panics on a malformed
                                        // control frame from the peer — consider
                                        // graceful handling.
                                        let msg: ControlMessage = serde_json::from_slice(bytes).unwrap();
                                        match msg {
                                            ControlMessage::Stop => {
                                                context.stop();
                                                break;
                                            }
                                            ControlMessage::Kill => {
                                                context.kill();
                                                break;
                                            }
                                        }
                                    }
                                    _ => {
                                        // we should not receive this
                                    }
                                }
                            }
                            Some(Err(e)) => {
                                // NOTE(review): a decode error aborts the task via
                                // panic — consider propagating instead.
                                panic!("failed to decode message from stream: {:?}", e);
                            }
                            None => {
                                // the stream was closed, we should stop the stream
                                return;
                            }
                        }
                    }
                    _ = alive_tx.closed() => {
                        // the channel was closed, we should stop the stream
                        break;
                    }
                }
            }
        });

        // transport specific handshake message pairing this socket with the
        // registered subject on the server
        let handshake = CallHomeHandshake {
            subject: info.subject,
            stream_type: StreamType::Response,
        };

        // Serializing this small struct cannot fail, so unwrap is safe here.
        let handshake_bytes = serde_json::to_vec(&handshake).unwrap();
        let msg = TwoPartMessage::from_header(handshake_bytes.into());

        // issue the first tcp handshake message
        framed_writer
            .send(msg)
            .await
            .map_err(|e| format!("failed to send handshake: {:?}", e))?;

        // set up the channel to send bytes to the transport layer
        let (bytes_tx, mut bytes_rx) = tokio::sync::mpsc::channel(16);

        // forwards the bytes sent from this stream to the transport layer;
        // holds the alive_rx half of the oneshot channel so its drop signals
        // the monitor task above
        tokio::spawn(async move {
            while let Some(msg) = bytes_rx.recv().await {
                if let Err(e) = framed_writer.send(msg).await {
                    log::trace!(
                        "failed to send message to stream; possible disconnect: {:?}",
                        e
                    );
                    // TODO - possibly propagate the error upstream
                    break;
                }
            }
            drop(alive_rx);
            if let Err(e) = framed_writer.get_mut().shutdown().await {
                log::trace!("failed to shutdown writer: {:?}", e);
            }
        });

        // set up the prologue for the stream
        // this might have transport specific metadata in the future
        let prologue = Some(ResponseStreamPrologue { error: None });

        // create the stream sender
        let stream_sender = StreamSender {
            tx: bytes_tx,
            prologue,
        };

        Ok(stream_sender)
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use anyhow::Result;
use core::panic;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
use bytes::Bytes;
use derive_builder::Builder;
use futures::StreamExt;
use local_ip_address::{list_afinet_netifas, local_ip};
use serde::{Deserialize, Serialize};
use tokio::{
io::AsyncWriteExt,
sync::{mpsc, oneshot},
};
use tokio_util::codec::{FramedRead, FramedWrite};
use super::{
CallHomeHandshake, PendingConnections, RegisteredStream, StreamOptions, StreamReceiver,
StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
};
use crate::engine::AsyncEngineContext;
use crate::pipeline::{
network::{
codec::{TwoPartMessage, TwoPartMessageType},
tcp::StreamType,
ResponseService, ResponseStreamPrologue,
},
PipelineError,
};
// Alias kept for readability within this module; currently unused.
#[allow(dead_code)]
type ResponseType = TwoPartMessage;
/// Configuration for a [`TcpStreamServer`].
#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
pub struct ServerOptions {
    /// Port to bind; `0` lets the OS choose an ephemeral port (the server
    /// reports the actual bound port after startup).
    #[builder(default = "0")]
    pub port: u16,
    /// Optional network interface name to bind; defaults to the local IP.
    #[builder(default)]
    pub interface: Option<String>,
}
impl ServerOptions {
pub fn builder() -> ServerOptionsBuilder {
ServerOptionsBuilder::default()
}
}
// todo - rename TcpResponseServer
// we may need to disambiguate this and a TcpRequestServer

/// A [`TcpStreamServer`] is a TCP service that listens on a port for incoming response connections.
/// A Response connection is a connection that is established by a client with the intention of sending
/// specific data back to the server. The key differentiating factor is that a [`ResponseServer`] is
/// expecting a connection from a client with an established subject.
pub struct TcpStreamServer {
    /// IP address the listener is bound to.
    local_ip: String,
    /// Port actually bound (may differ from the requested port when 0 was requested).
    local_port: u16,
    /// Pending registrations and the listener task handle.
    state: Arc<Mutex<State>>,
}
// pub struct TcpStreamReceiver {
// address: TcpStreamConnectionInfo,
// state: Arc<Mutex<State>>,
// rx: mpsc::Receiver<ResponseType>,
// }
/// A pending request-stream registration awaiting a client connection.
#[allow(dead_code)]
struct RequestedSendConnection {
    /// Engine context the stream belongs to.
    context: Arc<dyn AsyncEngineContext>,
    /// Fulfilled with a [`StreamSender`] (or error) once a client connects.
    connection: oneshot::Sender<Result<StreamSender, String>>,
}
/// A pending response-stream registration awaiting a client connection.
struct RequestedRecvConnection {
    /// Engine context the stream belongs to.
    context: Arc<dyn AsyncEngineContext>,
    /// Fulfilled with a [`StreamReceiver`] (or error) once a client connects.
    connection: oneshot::Sender<Result<StreamReceiver, String>>,
}
// /// When registering a new TcpStream on the server, the registration method will return a [`Connections`] object.
// /// This [`Connections`] object will have two [`oneshot::Receiver`] objects, one for the [`TcpStreamSender`] and one for the [`TcpStreamReceiver`].
// /// The [`Connections`] object can be awaited to get the [`TcpStreamSender`] and [`TcpStreamReceiver`] objects; these objects will
// /// be made available when the matching Client has connected to the server.
// pub struct Connections {
// pub address: TcpStreamConnectionInfo,
// /// The [`oneshot::Receiver`] for the [`TcpStreamSender`]. Awaiting this object will return the [`TcpStreamSender`] object once
// /// the client has connected to the server.
// pub sender: Option<oneshot::Receiver<StreamSender>>,
// /// The [`oneshot::Receiver`] for the [`TcpStreamReceiver`]. Awaiting this object will return the [`TcpStreamReceiver`] object once
// /// the client has connected to the server.
// pub receiver: Option<oneshot::Receiver<StreamReceiver>>,
// }
/// Shared server state guarded by a mutex.
#[derive(Default)]
struct State {
    /// Pending request-stream registrations, keyed by subject.
    tx_subjects: HashMap<String, RequestedSendConnection>,
    /// Pending response-stream registrations, keyed by subject.
    rx_subjects: HashMap<String, RequestedRecvConnection>,
    /// Handle of the spawned TCP listener task, once started.
    handle: Option<tokio::task::JoinHandle<()>>,
}
impl TcpStreamServer {
    /// Returns a builder for [`ServerOptions`].
    pub fn options_builder() -> ServerOptionsBuilder {
        ServerOptionsBuilder::default()
    }

    /// Creates and starts a new server.
    ///
    /// Resolves the bind address from `options.interface` (or the default
    /// local IP), spawns the listener task, and returns once the listener has
    /// bound and reported its actual port.
    ///
    /// # Errors
    /// Returns a [`PipelineError::Generic`] when the interface is unknown,
    /// the local IP cannot be determined, or the listener fails to start.
    pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
        let local_ip = match options.interface {
            Some(interface) => {
                let interfaces: HashMap<String, std::net::IpAddr> =
                    list_afinet_netifas()?.into_iter().collect();
                interfaces
                    .get(&interface)
                    .ok_or(PipelineError::Generic(format!(
                        "Interface not found: {}",
                        interface
                    )))?
                    .to_string()
            }
            // Was `.unwrap()`: surface a failure to determine the local IP as
            // a pipeline error instead of panicking.
            None => local_ip()
                .map_err(|e| {
                    PipelineError::Generic(format!(
                        "Failed to determine local IP address: {:?}",
                        e
                    ))
                })?
                .to_string(),
        };
        let state = Arc::new(Mutex::new(State::default()));
        let local_port = Self::start(local_ip.clone(), options.port, state.clone())
            .await
            .map_err(|e| {
                PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
            })?;
        tracing::info!("TcpStreamServer started on {}:{}", local_ip, local_port);
        Ok(Arc::new(Self {
            local_ip,
            local_port,
            state,
        }))
    }

    /// Spawns the TCP listener task and waits for it to report the bound port
    /// (relevant when port 0 was requested).
    ///
    /// # Panics
    /// Panics if the listener was already started on this state — starting it
    /// twice is a programming error.
    #[allow(clippy::await_holding_lock)]
    async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
        let addr = format!("{}:{}", local_ip, local_port);
        let state_clone = state.clone();
        let mut guard = state.lock().await;
        if guard.handle.is_some() {
            panic!("TcpStreamServer already started");
        }
        // The listener reports its actual bound port through this oneshot.
        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
        let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
        guard.handle = Some(handle);
        drop(guard);
        let local_port = ready_rx.await??;
        Ok(local_port)
    }
}
// todo - possible rename ResponseService to ResponseServer
#[async_trait::async_trait]
impl ResponseService for TcpStreamServer {
    /// Register a new subject and sender with the response subscriber.
    /// Produces an RAII object that will deregister the subject when dropped.
    ///
    /// We need to register both data-in and data-out entries:
    /// there might be a forward pipeline that wants to consume the data-out stream,
    /// and there might be a response stream that wants to consume the data-in stream.
    /// On registration, we need to specify if we want data-in, data-out or both;
    /// this maps to the type of service that is running, i.e. Single or Many In //
    /// Single or Many Out.
    ///
    /// todo(ryan) - return a connection object that can be awaited. when successfully connected,
    /// can ask for the sender and receiver
    ///
    /// OR
    ///
    /// we make it into register sender and register receiver, both would return a connection object
    /// and when a connection is established, we'd get the respective sender or receiver
    ///
    /// the registration probably needs to be done in one-go, so we should use a builder object for
    /// requesting a receiver and optional sender
    async fn register(&self, options: StreamOptions) -> PendingConnections {
        // oneshot channels to pass back the sender and receiver objects
        let address = format!("{}:{}", self.local_ip, self.local_port);
        tracing::debug!("Registering new TcpStream on {}", address);

        // Optional request (sender) stream: allocate a fresh subject, stash
        // the pending oneshot in server state, and hand back the descriptor.
        let send_stream = if options.enable_request_stream {
            let sender_subject = uuid::Uuid::new_v4().to_string();
            let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
            let connection_info = RequestedSendConnection {
                context: options.context.clone(),
                connection: pending_sender_tx,
            };
            let mut state = self.state.lock().await;
            state
                .tx_subjects
                .insert(sender_subject.clone(), connection_info);
            let registered_stream = RegisteredStream {
                connection_info: TcpStreamConnectionInfo {
                    address: address.clone(),
                    subject: sender_subject.clone(),
                    context: options.context.id().to_string(),
                    stream_type: StreamType::Request,
                }
                .into(),
                stream_provider: pending_sender_rx,
            };
            Some(registered_stream)
        } else {
            None
        };

        // Optional response (receiver) stream: same pattern, keyed in the
        // rx_subjects table instead.
        let recv_stream = if options.enable_response_stream {
            let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
            let receiver_subject = uuid::Uuid::new_v4().to_string();
            let connection_info = RequestedRecvConnection {
                context: options.context.clone(),
                connection: pending_recver_tx,
            };
            let mut state = self.state.lock().await;
            state
                .rx_subjects
                .insert(receiver_subject.clone(), connection_info);
            let registered_stream = RegisteredStream {
                connection_info: TcpStreamConnectionInfo {
                    address: address.clone(),
                    subject: receiver_subject.clone(),
                    context: options.context.id().to_string(),
                    stream_type: StreamType::Response,
                }
                .into(),
                stream_provider: pending_recver_rx,
            };
            Some(registered_stream)
        } else {
            None
        };

        PendingConnections {
            send_stream,
            recv_stream,
        }
    }
}
// this method listens on a tcp port for incoming connections
// new connections are expected to send a protocol specific handshake
// for us to determine the subject they are interested in, in this case,
// we expect the first message to be [`FirstMessage`] from which we find
// the sender, then we spawn a task to forward all bytes from the tcp stream
// to the sender
// this method listens on a tcp port for incoming connections
// new connections are expected to send a protocol specific handshake
// for us to determine the subject they are interested in, in this case,
// we expect the first message to be [`FirstMessage`] from which we find
// the sender, then we spawn a task to forward all bytes from the tcp stream
// to the sender
async fn tcp_listener(
    addr: String,
    state: Arc<Mutex<State>>,
    read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
) {
    // Bind the listener; success (with the bound port) or failure is reported
    // through `read_tx` so the caller can await readiness.
    let listener = tokio::net::TcpListener::bind(&addr)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
    let listener = match listener {
        Ok(listener) => {
            let addr = listener
                .local_addr()
                .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
                .unwrap();
            read_tx
                .send(Ok(addr.port()))
                .expect("Failed to send ready signal");
            listener
        }
        Err(e) => {
            read_tx.send(Err(e)).expect("Failed to send ready signal");
            return;
        }
    };
    // Accept loop: each connection is processed on its own task.
    // NOTE(review): `accept().unwrap()` aborts the listener task on a
    // transient accept error — consider logging and continuing instead.
    loop {
        let (stream, _addr) = listener.accept().await.unwrap();
        stream.set_nodelay(true).unwrap();
        tokio::spawn(handle_connection(stream, state.clone()));
    }
// #[instrument(level = "trace"), skip(state)]
// todo - clone before spawn and trace process_stream
/// Runs the handshake/stream processing for one accepted socket and logs the
/// outcome (trace on clean close, error on failure).
async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
    match process_stream(stream, state).await {
        Ok(_) => tracing::trace!("TcpStream connection closed"),
        Err(e) => tracing::error!("TcpStream connection failed: {}", e),
    }
}
/// This method is responsible for the internal tcp stream handshake
/// The handshake will specialize the stream as a request/sender or response/receiver stream
/// Performs the internal TCP stream handshake.
///
/// Reads the first framed message — which must be a header-only
/// [`CallHomeHandshake`] — then specializes the socket as a request/sender or
/// response/receiver stream.
///
/// # Errors
/// Returns a message when the connection closes early, the handshake is
/// malformed, or downstream stream processing fails.
async fn process_stream(
    stream: tokio::net::TcpStream,
    state: Arc<Mutex<State>>,
) -> Result<(), String> {
    // split the socket in to a reader and writer
    let (read_half, write_half) = tokio::io::split(stream);

    // attach the codec to the reader and writer to get framed readers and writers
    let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
    let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());

    // the internal tcp [`CallHomeHandshake`] connects the socket to the requester;
    // await this first message as a raw two-part frame
    // (fixed error-message typo: was "ControlMessge")
    let first_message = framed_reader
        .next()
        .await
        .ok_or("Connection closed without a ControlMessage".to_string())?
        .map_err(|e| e.to_string())?;

    // the handshake must arrive as a header-only message
    // todo - improve error handling - check for no data
    if first_message.header().is_none() {
        return Err("Expected ControlMessage, got DataMessage".to_string());
    }

    // deserialize the [`CallHomeHandshake`] message
    let handshake: CallHomeHandshake = serde_json::from_slice(first_message.header().unwrap())
        .map_err(|e| {
            format!(
                "Failed to deserialize the first message as a valid `CallHomeHandshake`: {}",
                e
            )
        })?;

    // branch here to handle sender stream or receiver stream
    match handshake.stream_type {
        StreamType::Request => process_request_stream().await,
        StreamType::Response => {
            process_response_stream(handshake.subject, state, framed_reader, framed_writer)
                .await
        }
    }
    .map_err(|e| format!("Failed to process stream: {}", e))
}
/// Handles a socket declared as a request stream.
///
/// NOTE(review): request streams are not yet implemented — the socket is
/// accepted and immediately considered complete.
async fn process_request_stream() -> Result<(), String> {
    Ok(())
}
/// Pairs an incoming response-stream socket with its registered receiver.
///
/// Claims the pending registration for `subject`, awaits the client's
/// prologue (success or error), hands a [`StreamReceiver`] back to the
/// registrant, then runs the monitor and forwarding tasks to completion.
async fn process_response_stream(
    subject: String,
    state: Arc<Mutex<State>>,
    mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
    writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
) -> Result<(), String> {
    // Remove (not just read) the registration so the subject is single-use.
    let response_stream = state
        .lock().await
        .rx_subjects
        .remove(&subject)
        .ok_or(format!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;

    // unwrap response_stream
    let RequestedRecvConnection {
        context,
        connection,
    } = response_stream;

    // the [`Prologue`]
    // there must be a second control message to indicate the other segment's generate method was successful
    let prologue = reader
        .next()
        .await
        .ok_or("Connection closed without a ControlMessge".to_string())?
        .map_err(|e| e.to_string())?;

    // deserialize prologue — it must arrive as a header-only frame
    let prologue = match prologue.into_message_type() {
        TwoPartMessageType::HeaderOnly(header) => {
            let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
                .map_err(|e| format!("Failed to deserialize ControlMessage: {}", e))?;
            prologue
        }
        _ => {
            panic!("Expected HeaderOnly ControlMessage; internally logic error")
        }
    };

    // await the control message of GTG or Error, if error, then connection.send(Err(String)), which should fail the
    // generate call chain
    //
    // note: this second control message might be delayed, but the expensive part of setting up the connection
    // is both complete and ready for data flow; awaiting here is not a performance hit or problem and it allows
    // us to trace the initial setup time vs the time to prologue
    if let Some(error) = &prologue.error {
        let _ = connection.send(Err(error.clone()));
        return Err(format!("Received error prologue: {}", error));
    }

    // we need to know the buffer size from the registration options; add this to the RequestRecvConnection object
    let (tx, rx) = mpsc::channel(16);
    if connection
        .send(Ok(crate::pipeline::network::StreamReceiver { rx }))
        .is_err()
    {
        return Err("The requester of the stream has been dropped before the connection was established".to_string());
    }

    let (alive_tx, alive_rx) = mpsc::channel::<()>(1);
    let (control_tx, _control_rx) = mpsc::channel::<Bytes>(8);

    // monitor task
    // if the context is cancelled, we need to forward the message across the transport layer
    // we only determine the forwarding task on a kill signal, on a stop signal, we issue the stop signal, then await for the producer
    // to naturally close the stream
    let monitor_task = tokio::spawn(monitor(writer, context.clone(), alive_tx));

    // forward task: moves socket frames into the receiver's channel
    let forward_task = tokio::spawn(handle_response_stream(
        reader,
        tx,
        control_tx,
        context.clone(),
        alive_rx,
    ));

    // check the results of each of the tasks
    let (monitor_result, forward_result) = tokio::join!(monitor_task, forward_task);

    // if either of the tasks failed (panicked/was cancelled), return an error
    if let Err(e) = monitor_result {
        return Err(format!("Monitor task failed: {}", e));
    }
    if let Err(e) = forward_result {
        return Err(format!("Forward task failed: {}", e));
    }

    Ok(())
}
/// Forward frames read from the TCP stream to the local channels.
///
/// Each [`TwoPartCodec`] frame splits into a header part and a data part:
/// non-empty headers go to `control_tx`, non-empty data goes to `response_tx`.
/// The loop ends when the stream closes, the response receiver is dropped, or
/// the context is killed. Dropping `alive_rx` on exit signals the monitor task
/// that the stream has finished.
async fn handle_response_stream(
    mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
    response_tx: mpsc::Sender<Bytes>,
    control_tx: mpsc::Sender<Bytes>,
    context: Arc<dyn AsyncEngineContext>,
    alive_rx: mpsc::Receiver<()>,
) -> Result<(), String> {
    // loop over reading the tcp stream and checking if the writer is closed
    loop {
        tokio::select! {
            msg = framed_reader.next() => {
                match msg {
                    Some(Ok(msg)) => {
                        let (header, data) = msg.into_parts();
                        if !header.is_empty() && (control_tx.send(header).await).is_err() {
                            tracing::trace!("Control channel closed")
                        }
                        if !data.is_empty() && response_tx.send(data).await.is_err() {
                            // the receiver can be dropped between select polls (the
                            // `response_tx.closed()` arm below proves this race exists);
                            // stop forwarding instead of panicking (was `.unwrap()`)
                            tracing::trace!("Response channel closed while forwarding data");
                            break;
                        }
                    }
                    Some(Err(e)) => {
                        return Err(format!("Failed to read TwoPartCodec message from TcpStream: {}", e));
                    }
                    None => {
                        tracing::trace!("TcpStream closed naturally");
                        break;
                    }
                }
            }
            _ = response_tx.closed() => {
                break;
            }
            _ = context.killed() => { break; }
        }
    }
    drop(alive_rx);
    Ok(())
}
/// Drain control-plane messages until the channel closes or the context is killed.
///
/// Control payloads are currently ignored; the loop exists so the channel is
/// consumed and so `alive_tx` stays open until completion (its drop signals
/// the monitor task).
#[allow(dead_code)]
async fn handle_control_message(
    mut control_rx: mpsc::Receiver<Bytes>,
    context: Arc<dyn AsyncEngineContext>,
    alive_tx: mpsc::Sender<()>,
) -> Result<(), String> {
    loop {
        tokio::select! {
            maybe_msg = control_rx.recv() => {
                if maybe_msg.is_none() {
                    tracing::trace!("Control channel closed");
                    break;
                }
                // control message handling is not implemented yet
            }
            _ = context.killed() => {
                break;
            }
        }
    }
    drop(alive_tx);
    Ok(())
}
async fn monitor(
_socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
ctx: Arc<dyn AsyncEngineContext>,
alive_tx: mpsc::Sender<()>,
) {
let alive_tx = alive_tx;
tokio::select! {
_ = ctx.stopped() => {
// send cancellation message
panic!("impl cancellation signal");
}
_ = alive_tx.closed() => {
tracing::trace!("response stream closed naturally")
}
}
let mut framed_writer = _socket_tx;
framed_writer.get_mut().shutdown().await.unwrap();
}
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! Pipeline Nodes
//!
//! A `ServicePipeline` is a directed graph of nodes where each node defines a behavior for both
//! forward/request path and the backward/response path. The allowed behaviors in each direction
//! are either a `Source` or a `Sink`.
//!
//! A `Frontend` is the start of a graph and is a [`Source`] for the forward path and a [`Sink`] for the
//! backward path.
//!
//! A `Backend` is the end of a graph and is a [`Sink`] for the forward path and a [`Source`] for the
//! backward path.
//!
//! A [`PipelineOperator`] is a node that can transform both the forward and backward paths using the
//! logic supplied by the implementation of an [`Operator`] trait. Because the [`PipelineOperator`] is
//! both a [`Source`] and a [`Sink`] for the forward request path and the backward response path
//! respectively — i.e. it is two sources and two sinks — we differentiate the two by using the
//! [`PipelineOperator::forward_edge`] and [`PipelineOperator::backward_edge`] methods.
//!
//! - The [`PipelineOperator::forward_edge`] returns a [`PipelineOperatorForwardEdge`] which is a [`Sink`]
//! for incoming/upstream request and a [`Source`] for the downstream request.
//! - The [`PipelineOperator::backward_edge`] returns a [`PipelineOperatorBackwardEdge`] which is a [`Sink`]
//! for the downstream response and a [`Source`] for the upstream response.
//!
//! An `EdgeOperator`, currently named [`PipelineNode`], is a node in the graph that can transform only a
//! forward or a backward path, but does not transform both.
//!
//! This makes the [`Operator`] a more powerful trait as it can propagate information from the forward
//! path to the backward path. An `EdgeOperator` on the forward path has no visibility into the backward
//! path and therefore, cannot directly influence the backward path.
//!
use std::{
collections::HashMap,
sync::{Arc, Mutex, OnceLock},
};
use super::AsyncEngine;
use async_trait::async_trait;
use tokio::sync::oneshot;
use super::{Data, Error, PipelineError, PipelineIO};
mod sinks;
mod sources;
pub use sinks::{SegmentSink, ServiceBackend};
pub use sources::{SegmentSource, ServiceFrontend};
pub type Service<In, Out> = Arc<ServiceFrontend<In, Out>>;
mod private {
pub struct Token;
}
// todo rename `ServicePipelineExt`
/// A [`Source`] trait defines how data is emitted from a source to a downstream sink
/// over an [`Edge`].
#[async_trait]
pub trait Source<T: PipelineIO>: Data {
    /// Emit `data` toward the downstream sink. The module-private `Token`
    /// keeps this callable only from within this module tree.
    async fn on_next(&self, data: T, _: private::Token) -> Result<(), Error>;

    /// Install the outgoing [`Edge`]. Implementations are expected to accept
    /// an edge at most once (see [`PipelineError::EdgeAlreadySet`]).
    fn set_edge(&self, edge: Edge<T>, _: private::Token) -> Result<(), PipelineError>;

    /// Connect this source to `sink` and return the sink so links can be chained.
    fn link<S: Sink<T> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
        let edge = Edge::new(sink.clone());
        self.set_edge(edge, private::Token)?;
        Ok(sink)
    }
}
/// A [`Sink`] trait defines how data is received from a source and processed.
#[async_trait]
pub trait Sink<T: PipelineIO>: Data {
    /// Receive `data` from an upstream [`Source`]. The module-private `Token`
    /// keeps this callable only from within this module tree.
    async fn on_data(&self, data: T, _: private::Token) -> Result<(), Error>;
}
/// An [`Edge`] is a connection between a [`Source`] and a [`Sink`]. Data flows over an [`Edge`].
pub struct Edge<T: PipelineIO> {
    // the sink that receives everything written to this edge
    downstream: Arc<dyn Sink<T>>,
}

impl<T: PipelineIO> Edge<T> {
    /// Wrap `downstream` as the receiving end of this edge.
    fn new(downstream: Arc<dyn Sink<T>>) -> Self {
        Edge { downstream }
    }

    /// Push `data` into the downstream sink.
    async fn write(&self, data: T) -> Result<(), Error> {
        self.downstream.on_data(data, private::Token).await
    }
}
/// Synchronous, fallible transformation applied by a [`PipelineNode`] to each item.
type NodeFn<In, Out> = Box<dyn Fn(In) -> Result<Out, Error> + Send + Sync>;
/// An [`Operator`] is a trait that defines the behavior of how two [`AsyncEngine`] can be chained together.
/// An [`Operator`] is not quite an [`AsyncEngine`] because its generate method requires both the upstream
/// request, but also the downstream [`AsyncEngine`] to which it will pass the transformed request.
/// The [`Operator`] logic must transform the upstream request `UpIn` to the downstream request `DownIn`,
/// then transform the downstream response `DownOut` to the upstream response `UpOut`.
///
/// A [`PipelineOperator`] accepts an [`Operator`] and presents itself as an [`AsyncEngine`] for the upstream
/// [`AsyncEngine<UpIn, UpOut, Error>`].
///
/// ### Example of type transformation and data flow
/// ```text
/// ... --> <UpIn>  ---> [Operator] --> <DownIn>  ---> ...
/// ... <-- <UpOut> <--- [Operator] <-- <DownOut> <-- ...
/// ```
#[async_trait]
pub trait Operator<UpIn: PipelineIO, UpOut: PipelineIO, DownIn: PipelineIO, DownOut: PipelineIO>:
    Data
{
    /// This method is expected to transform the upstream request `UpIn` to the downstream request `DownIn`,
    /// call the next [`AsyncEngine`] with the transformed request, then transform the downstream response
    /// `DownOut` to the upstream response `UpOut`.
    async fn generate(
        &self,
        req: UpIn,
        next: Arc<dyn AsyncEngine<DownIn, DownOut, Error>>,
    ) -> Result<UpOut, Error>;

    /// Wrap this [`Operator`] in a [`PipelineOperator`] so it can be linked into a graph.
    fn into_operator(self: &Arc<Self>) -> Arc<PipelineOperator<UpIn, UpOut, DownIn, DownOut>>
    where
        Self: Sized,
    {
        PipelineOperator::new(self.clone())
    }
}
/// A [`PipelineOperatorForwardEdge`] is [`Sink`] for the upstream request type `UpIn` and a [`Source`] for the
/// downstream request type `DownIn`.
pub struct PipelineOperatorForwardEdge<
    UpIn: PipelineIO,
    UpOut: PipelineIO,
    DownIn: PipelineIO,
    DownOut: PipelineIO,
> {
    // the operator node this edge belongs to; all calls delegate to it
    parent: Arc<PipelineOperator<UpIn, UpOut, DownIn, DownOut>>,
}
/// A [`PipelineOperatorBackwardEdge`] is [`Sink`] for the downstream response type `DownOut` and a [`Source`] for the
/// upstream response type `UpOut`.
pub struct PipelineOperatorBackwardEdge<
    UpIn: PipelineIO,
    UpOut: PipelineIO,
    DownIn: PipelineIO,
    DownOut: PipelineIO,
> {
    // the operator node this edge belongs to; all calls delegate to it
    parent: Arc<PipelineOperator<UpIn, UpOut, DownIn, DownOut>>,
}
/// A [`PipelineOperator`] is a node that can transform both the forward and backward paths using the logic defined
/// by the implementation of an [`Operator`] trait.
pub struct PipelineOperator<
    UpIn: PipelineIO,
    UpOut: PipelineIO,
    DownIn: PipelineIO,
    DownOut: PipelineIO,
> {
    // core business logic of this object
    operator: Arc<dyn Operator<UpIn, UpOut, DownIn, DownOut>>,
    // this holds the downstream connections via the generic frontend;
    // frontends provide both a source and a sink interface
    downstream: Arc<sources::Frontend<DownIn, DownOut>>,
    // this holds the connection to the previous/upstream response sink;
    // we are a source to that upstream's response sink
    upstream: sinks::SinkEdge<UpOut>,
}
impl<UpIn, UpOut, DownIn, DownOut> PipelineOperator<UpIn, UpOut, DownIn, DownOut>
where
    UpIn: PipelineIO,
    UpOut: PipelineIO,
    DownIn: PipelineIO,
    DownOut: PipelineIO,
{
    /// Create a new [`PipelineOperator`] wrapping the supplied [`Operator`] logic.
    pub fn new(operator: Arc<dyn Operator<UpIn, UpOut, DownIn, DownOut>>) -> Arc<Self> {
        let downstream = Arc::new(sources::Frontend::default());
        let upstream = sinks::SinkEdge::default();
        Arc::new(Self {
            operator,
            downstream,
            upstream,
        })
    }

    /// The forward edge: lets the forward/request path be linked into the graph.
    pub fn forward_edge(
        self: &Arc<Self>,
    ) -> Arc<PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>> {
        let parent = Arc::clone(self);
        Arc::new(PipelineOperatorForwardEdge { parent })
    }

    /// The backward edge: lets the backward/response path be linked into the graph.
    pub fn backward_edge(
        self: &Arc<Self>,
    ) -> Arc<PipelineOperatorBackwardEdge<UpIn, UpOut, DownIn, DownOut>> {
        let parent = Arc::clone(self);
        Arc::new(PipelineOperatorBackwardEdge { parent })
    }
}
/// A [`PipelineOperator`] is an [`AsyncEngine`] for the upstream [`AsyncEngine<UpIn, UpOut, Error>`].
#[async_trait]
impl<UpIn, UpOut, DownIn, DownOut> AsyncEngine<UpIn, UpOut, Error>
    for PipelineOperator<UpIn, UpOut, DownIn, DownOut>
where
    UpIn: PipelineIO,
    DownIn: PipelineIO,
    DownOut: PipelineIO,
    UpOut: PipelineIO,
{
    /// Delegate to the wrapped [`Operator`], handing it the downstream frontend
    /// as the "next" engine in the chain.
    async fn generate(&self, req: UpIn) -> Result<UpOut, Error> {
        self.operator.generate(req, self.downstream.clone()).await
    }
}
#[async_trait]
impl<UpIn, UpOut, DownIn, DownOut> Sink<UpIn>
    for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>
where
    UpIn: PipelineIO,
    UpOut: PipelineIO,
    DownIn: PipelineIO,
    DownOut: PipelineIO,
{
    /// Run the operator (which drives the downstream engine) on the incoming
    /// request, then push the response onto the parent's upstream path.
    async fn on_data(&self, data: UpIn, _token: private::Token) -> Result<(), Error> {
        let response = self.parent.generate(data).await?;
        self.parent.upstream.on_next(response, private::Token).await
    }
}
#[async_trait]
impl<UpIn, UpOut, DownIn, DownOut> Source<DownIn>
for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>
where
UpIn: PipelineIO,
DownIn: PipelineIO,
DownOut: PipelineIO,
UpOut: PipelineIO,
{
async fn on_next(&self, data: DownIn, token: private::Token) -> Result<(), Error> {
self.parent.downstream.on_next(data, token).await
}
fn set_edge(&self, edge: Edge<DownIn>, token: private::Token) -> Result<(), PipelineError> {
self.parent.downstream.set_edge(edge, token)
}
}
#[async_trait]
impl<UpIn, UpOut, DownIn, DownOut> Sink<DownOut>
for PipelineOperatorBackwardEdge<UpIn, UpOut, DownIn, DownOut>
where
UpIn: PipelineIO,
DownIn: PipelineIO,
DownOut: PipelineIO,
UpOut: PipelineIO,
{
async fn on_data(&self, data: DownOut, token: private::Token) -> Result<(), Error> {
self.parent.downstream.on_data(data, token).await
}
}
#[async_trait]
impl<UpIn, UpOut, DownIn, DownOut> Source<UpOut>
for PipelineOperatorBackwardEdge<UpIn, UpOut, DownIn, DownOut>
where
UpIn: PipelineIO,
DownIn: PipelineIO,
DownOut: PipelineIO,
UpOut: PipelineIO,
{
async fn on_next(&self, data: UpOut, token: private::Token) -> Result<(), Error> {
self.parent.upstream.on_next(data, token).await
}
fn set_edge(&self, edge: Edge<UpOut>, token: private::Token) -> Result<(), PipelineError> {
self.parent.upstream.set_edge(edge, token)
}
}
/// An edge operator: transforms data flowing in a single direction via `map_fn`.
pub struct PipelineNode<In: PipelineIO, Out: PipelineIO> {
    // outgoing edge; set exactly once when the node is linked
    edge: OnceLock<Edge<Out>>,
    // synchronous transformation applied to each item passing through
    map_fn: NodeFn<In, Out>,
}
impl<In: PipelineIO, Out: PipelineIO> PipelineNode<In, Out> {
    /// Build a node that applies `map_fn` to every item flowing through it.
    pub fn new(map_fn: NodeFn<In, Out>) -> Arc<Self> {
        Arc::new(Self {
            map_fn,
            edge: OnceLock::new(),
        })
    }
}
#[async_trait]
impl<In: PipelineIO, Out: PipelineIO> Source<Out> for PipelineNode<In, Out> {
    /// Forward `data` over the configured edge, or fail with [`PipelineError::NoEdge`].
    async fn on_next(&self, data: Out, _: private::Token) -> Result<(), Error> {
        match self.edge.get() {
            Some(edge) => edge.write(data).await,
            None => Err(PipelineError::NoEdge.into()),
        }
    }

    /// Install the outgoing edge; fails if one was already set.
    fn set_edge(&self, edge: Edge<Out>, _: private::Token) -> Result<(), PipelineError> {
        self.edge
            .set(edge)
            .map_err(|_| PipelineError::EdgeAlreadySet)
    }
}
#[async_trait]
impl<In: PipelineIO, Out: PipelineIO> Sink<In> for PipelineNode<In, Out> {
    /// Apply the node's transformation, then emit the result downstream.
    async fn on_data(&self, data: In, _: private::Token) -> Result<(), Error> {
        let mapped = (self.map_fn)(data)?;
        self.on_next(mapped, private::Token).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::*;

    /// A frontend with no linked sink must fail to generate.
    #[tokio::test]
    async fn test_pipeline_source_no_edge() {
        let frontend = ServiceFrontend::<SingleIn<()>, ManyOut<()>>::new();
        let result = frontend.generate(().into()).await;
        assert!(result.is_err());
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::{
async_trait, private::Token, Arc, Edge, OnceLock, PipelineError, Service, Sink, Source,
};
use crate::pipeline::{PipelineIO, ServiceEngine};
mod base;
mod pipeline;
mod segment;
/// Holds the write-once outgoing [`Edge`] used by response [`Source`] implementations.
pub(crate) struct SinkEdge<Resp: PipelineIO> {
    // set exactly once when the response path is linked
    edge: OnceLock<Edge<Resp>>,
}
/// The end of a graph: a [`Sink`] for requests that drives a [`ServiceEngine`],
/// and a [`Source`] for the responses it produces.
pub struct ServiceBackend<Req: PipelineIO, Resp: PipelineIO> {
    // the engine invoked for every incoming request
    engine: ServiceEngine<Req, Resp>,
    // outgoing response edge
    inner: SinkEdge<Resp>,
}
// todo - use a once lock of a TransportEngine
/// Like [`ServiceBackend`], but the engine is attached after construction (via `attach`).
pub struct SegmentSink<Req: PipelineIO, Resp: PipelineIO> {
    // attached at most once; requests fail with NoNetworkEdge until then
    engine: OnceLock<ServiceEngine<Req, Resp>>,
    // outgoing response edge
    inner: SinkEdge<Resp>,
}
#[allow(dead_code)]
/// Placeholder for a network egress port; not yet wired up (see the commented-out impl below).
pub struct EgressPort<Req: PipelineIO, Resp: PipelineIO> {
    engine: Service<Req, Resp>,
}
// impl<Resp: PipelineIO> SegmentSink<Req, Resp> {
// pub connect(&self)
// }
// impl<Req, Resp> EgressPort<Req, Resp>
// where
// Req: PipelineIO + Serialize,
// Resp: for<'de> Deserialize<'de> + DataType,
// {
// }
// #[async_trait]
// impl<Req, Resp> AsyncEngine<Context<Req>, Annotated<Resp>> for EgressPort<Req, Resp>
// where
// Req: PipelineIO + Serialize,
// Resp: for<'de> Deserialize<'de> + DataType,
// {
// async fn generate(&self, request: Context<Req>) -> Result<Resp, GenerateError> {
// // when publish our request, we need to publish it with a subject
// // we will use a trait in the future
// let tx_subject = "tx-model-subject".to_string();
// let rx_subject = "rx-model-subject".to_string();
// // make a response channel
// let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
// // register the bytes_tx sender with the response subject
// // let bytes_stream = self.response_subscriber.register(rx_subject, bytes_tx);
// // ask network impl for a Sender to the cancellation channel
// let request = request
// .try_map(|req| bincode::serialize(&req))
// .map_err(|e| {
// GenerateError(format!(
// "Failed to serialize request in egress port: {}",
// e.to_string()
// ))
// })?;
// let (data, context) = request.transfer(());
// let stream_ctx = Arc::new(StreamContext::from(context));
// let shutdown_ctx = stream_ctx.clone();
// let (live_tx, live_rx) = tokio::sync::oneshot::channel::<()>();
// let byte_stream = ReceiverStream::new(bytes_rx);
// let decoded = byte_stream
// // decode the response
// .map(move |item| {
// bincode::deserialize::<Annotated<Resp>>(&item)
// .expect("failed to deserialize response")
// })
// .scan(Some(live_tx), move |live_tx, item| {
// match item {
// Annotated::End => {
// // this essentially drops the channel
// let _ = live_tx.take();
// }
// _ => {}
// }
// futures::future::ready(Some(item))
// });
// return Ok(ResponseStream::new(Box::pin(decoded), stream_ctx));
// }
// }
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use crate::Error;
impl<Resp: PipelineIO> Default for SinkEdge<Resp> {
fn default() -> Self {
Self {
edge: OnceLock::new(),
}
}
}
#[async_trait]
impl<Resp: PipelineIO> Source<Resp> for SinkEdge<Resp> {
    /// Forward `data` over the configured edge, or fail with [`PipelineError::NoEdge`].
    async fn on_next(&self, data: Resp, _: Token) -> Result<(), Error> {
        match self.edge.get() {
            Some(edge) => edge.write(data).await,
            None => Err(PipelineError::NoEdge.into()),
        }
    }

    /// Install the outgoing edge; fails if one was already set.
    fn set_edge(&self, edge: Edge<Resp>, _: Token) -> Result<(), PipelineError> {
        self.edge
            .set(edge)
            .map_err(|_| PipelineError::EdgeAlreadySet)
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use crate::Error;
impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> {
    /// Wrap `engine` as the terminal sink of a pipeline.
    pub fn from_engine(engine: ServiceEngine<Req, Resp>) -> Arc<Self> {
        let inner = SinkEdge::default();
        Arc::new(Self { engine, inner })
    }
}
#[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Sink<Req> for ServiceBackend<Req, Resp> {
    /// Drive the engine with the request and emit its response on the backward path.
    async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> {
        let response = self.engine.generate(data).await?;
        self.on_next(response, Token).await
    }
}
#[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Source<Resp> for ServiceBackend<Req, Resp> {
async fn on_next(&self, data: Resp, _: Token) -> Result<(), Error> {
self.inner.on_next(data, Token).await
}
fn set_edge(&self, edge: Edge<Resp>, _: Token) -> Result<(), PipelineError> {
self.inner.set_edge(edge, Token)
}
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use crate::Error;
impl<Req: PipelineIO, Resp: PipelineIO> SegmentSink<Req, Resp> {
    /// Create a sink with no engine attached yet.
    pub fn new() -> Arc<Self> {
        Arc::new(Default::default())
    }

    /// Attach the service engine; may be called at most once.
    pub fn attach(&self, engine: ServiceEngine<Req, Resp>) -> Result<(), PipelineError> {
        match self.engine.set(engine) {
            Ok(()) => Ok(()),
            Err(_) => Err(PipelineError::EdgeAlreadySet),
        }
    }
}
impl<Req: PipelineIO, Resp: PipelineIO> Default for SegmentSink<Req, Resp> {
fn default() -> Self {
Self {
engine: OnceLock::new(),
inner: SinkEdge::default(),
}
}
}
#[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Sink<Req> for SegmentSink<Req, Resp> {
    /// Forward the request to the attached engine, then emit the response on
    /// the backward path; fails with [`PipelineError::NoNetworkEdge`] when no
    /// engine has been attached.
    async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> {
        let engine = self.engine.get().ok_or(PipelineError::NoNetworkEdge)?;
        let response = engine.generate(data).await?;
        self.on_next(response, Token).await
    }
}
#[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Source<Resp> for SegmentSink<Req, Resp> {
async fn on_next(&self, data: Resp, _: Token) -> Result<(), Error> {
self.inner.on_next(data, Token).await
}
fn set_edge(&self, edge: Edge<Resp>, _: Token) -> Result<(), PipelineError> {
self.inner.set_edge(edge, Token)
}
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use crate::pipeline::{AsyncEngine, PipelineIO};
mod base;
mod common;
/// Shared frontend state: the forward edge plus a per-request response routing table.
pub struct Frontend<In: PipelineIO, Out: PipelineIO> {
    // forward edge, set once when the frontend is linked to a sink
    edge: OnceLock<Edge<In>>,
    // maps a request id to the oneshot sender that completes its `generate` call
    sinks: Arc<Mutex<HashMap<String, oneshot::Sender<Out>>>>,
}
/// A [`ServiceFrontend`] is the interface for an [`AsyncEngine<SingleIn<Context<In>>, ManyOut<Annotated<Out>>, Error>`]
pub struct ServiceFrontend<In: PipelineIO, Out: PipelineIO> {
    // all behavior delegates to the shared Frontend implementation
    inner: Frontend<In, Out>,
}
/// Frontend variant used at segment boundaries; behavior delegates to [`Frontend`].
pub struct SegmentSource<In: PipelineIO, Out: PipelineIO> {
    inner: Frontend<In, Out>,
}
// impl<In: DataType, Out: PipelineIO> Frontend<In, Out> {
// pub fn new() -> Arc<Self> {
// Arc::new(Self {
// edge: OnceLock::new(),
// sinks: Arc::new(Mutex::new(HashMap::new())),
// })
// }
// }
// impl<In: DataType, Out: PipelineIO> SegmentSource<In, Out> {
// pub fn new() -> Arc<Self> {
// Arc::new(Self {
// edge: OnceLock::new(),
// sinks: Arc::new(Mutex::new(HashMap::new())),
// })
// }
// }
// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for Frontend<In, Out> {
// async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
// self.edge
// .get()
// .ok_or(PipelineError::NoEdge)?
// .write(data)
// .await
// }
// fn set_edge(
// &self,
// edge: Edge<Context<In>>>,
// _: private::Token,
// ) -> Result<(), PipelineError> {
// self.edge
// .set(edge)
// .map_err(|_| PipelineError::EdgeAlreadySet)?;
// Ok(())
// }
// }
// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for Frontend<In, Out> {
// async fn on_data(
// &self,
// data: PipelineStream<Out>,
// _: private::Token,
// ) -> Result<(), PipelineError> {
// let context = data.context();
// let mut sinks = self.sinks.lock().unwrap();
// let tx = sinks
// .remove(context.id())
// .ok_or(PipelineError::DetatchedStreamReceiver)
// .map_err(|e| {
// data.context().stop_generating();
// e
// })?;
// drop(sinks);
// let ctx = data.context();
// tx.send(data)
// .map_err(|_| PipelineError::DetatchedStreamReceiver)
// .map_err(|e| {
// ctx.stop_generating();
// e
// })
// }
// }
// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for Frontend<In, Out> {
// fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
// let edge = Edge::new(sink.clone());
// self.set_edge(edge.into(), private::Token {})?;
// Ok(sink)
// }
// }
// #[async_trait]
// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
// for Frontend<In, Out>
// {
// async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
// let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
// {
// let mut sinks = self.sinks.lock().unwrap();
// sinks.insert(request.id().to_string(), tx);
// }
// self.on_next(request, private::Token {}).await?;
// rx.await.map_err(|_| PipelineError::DetatchedStreamSender)
// }
// }
// // SegmentSource
// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Source<Context<In>> for SegmentSource<In, Out> {
// async fn on_next(&self, data: Context<In>, _: private::Token) -> Result<(), PipelineError> {
// self.edge
// .get()
// .ok_or(PipelineError::NoEdge)?
// .write(data)
// .await
// }
// fn set_edge(
// &self,
// edge: Edge<Context<In>>>,
// _: private::Token,
// ) -> Result<(), PipelineError> {
// self.edge
// .set(edge)
// .map_err(|_| PipelineError::EdgeAlreadySet)?;
// Ok(())
// }
// }
// #[async_trait]
// impl<In: DataType, Out: PipelineIO> Sink<PipelineStream<Out>> for SegmentSource<In, Out> {
// async fn on_data(
// &self,
// data: PipelineStream<Out>,
// _: private::Token,
// ) -> Result<(), PipelineError> {
// let context = data.context();
// let mut sinks = self.sinks.lock().unwrap();
// let tx = sinks
// .remove(context.id())
// .ok_or(PipelineError::DetatchedStreamReceiver)
// .map_err(|e| {
// data.context().stop_generating();
// e
// })?;
// drop(sinks);
// let ctx = data.context();
// tx.send(data)
// .map_err(|_| PipelineError::DetatchedStreamReceiver)
// .map_err(|e| {
// ctx.stop_generating();
// e
// })
// }
// }
// impl<In: DataType, Out: PipelineIO> Link<Context<In>> for SegmentSource<In, Out> {
// fn link<S: Sink<Context<In>> + 'static>(&self, sink: Arc<S>) -> Result<Arc<S>, PipelineError> {
// let edge = Edge::new(sink.clone());
// self.set_edge(edge.into(), private::Token {})?;
// Ok(sink)
// }
// }
// #[async_trait]
// impl<In: DataType, Out: PipelineIO> AsyncEngine<Context<In>, Annotated<Out>, PipelineError>
// for SegmentSource<In, Out>
// {
// async fn generate(&self, request: Context<In>) -> Result<PipelineStream<Out>, PipelineError> {
// let (tx, rx) = oneshot::channel::<PipelineStream<Out>>();
// {
// let mut sinks = self.sinks.lock().unwrap();
// sinks.insert(request.id().to_string(), tx);
// }
// self.on_next(request, private::Token {}).await?;
// rx.await.map_err(|_| PipelineError::DetatchedStreamSender)
// }
// }
// #[cfg(test)]
// mod tests {
// use super::*;
// #[tokio::test]
// async fn test_pipeline_source_no_edge() {
// let source = Frontend::<(), ()>::new();
// let stream = source.generate(().into()).await;
// match stream {
// Err(PipelineError::NoEdge) => (),
// _ => panic!("Expected NoEdge error"),
// }
// }
// }
// pub struct IngressPort<In, Out: PipelineIO> {
// edge: OnceLock<ServiceEngine<In, Out>>,
// }
// impl<In, Out> IngressPort<In, Out>
// where
// In: for<'de> Deserialize<'de> + DataType,
// Out: PipelineIO + Serialize,
// {
// pub fn new() -> Arc<Self> {
// Arc::new(IngressPort {
// edge: OnceLock::new(),
// })
// }
// }
// #[async_trait]
// impl<In, Out> AsyncEngine<Context<Vec<u8>>, Vec<u8>> for IngressPort<In, Out>
// where
// In: for<'de> Deserialize<'de> + DataType,
// Out: PipelineIO + Serialize,
// {
// async fn generate(
// &self,
// request: Context<Vec<u8>>,
// ) -> Result<EngineStream<Vec<u8>>, PipelineError> {
// // Deserialize request
// let request = request.try_map(|bytes| {
// bincode::deserialize::<In>(&bytes)
// .map_err(|err| PipelineError(format!("Failed to deserialize request: {}", err)))
// })?;
// // Forward request to edge
// let stream = self
// .edge
// .get()
// .ok_or(PipelineError("No engine to forward request to".to_string()))?
// .generate(request)
// .await?;
// // Serialize response stream
// let stream =
// stream.map(|resp| bincode::serialize(&resp).expect("Failed to serialize response"));
// Err(PipelineError(format!("Not implemented")))
// }
// }
// fn convert_stream<T, U>(
// stream: impl Stream<Item = ServerStream<T>> + Send + 'static,
// ctx: Arc<dyn AsyncEngineContext>,
// transform: Arc<dyn Fn(T) -> Result<U, StreamError> + Send + Sync>,
// ) -> Pin<Box<dyn Stream<Item = ServerStream<U>> + Send>>
// where
// T: Send + 'static,
// U: Send + 'static,
// {
// Box::pin(stream.flat_map(move |item| {
// let ctx = ctx.clone();
// let transform = transform.clone();
// match item {
// ServerStream::Data(data) => match transform(data) {
// Ok(transformed) => futures::stream::iter(vec![ServerStream::Data(transformed)]),
// Err(e) => {
// // Trigger cancellation and propagate the error, followed by Sentinel
// ctx.stop_generating();
// futures::stream::iter(vec![ServerStream::Error(e), ServerStream::Sentinel])
// }
// },
// other => futures::stream::iter(vec![other]),
// }
// })
// // Use take_while to stop processing when encountering the Sentinel
// .take_while(|item| futures::future::ready(!matches!(item, ServerStream::Sentinel))))
// }
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use crate::engine::AsyncEngineContextProvider;
use super::*;
impl<In: PipelineIO, Out: PipelineIO> Default for Frontend<In, Out> {
fn default() -> Self {
Self {
edge: OnceLock::new(),
sinks: Arc::new(Mutex::new(HashMap::new())),
}
}
}
#[async_trait]
impl<In: PipelineIO, Out: PipelineIO> Source<In> for Frontend<In, Out> {
    /// Forward `data` over the configured edge, or fail with [`PipelineError::NoEdge`].
    async fn on_next(&self, data: In, _: private::Token) -> Result<(), Error> {
        match self.edge.get() {
            Some(edge) => edge.write(data).await,
            None => Err(PipelineError::NoEdge.into()),
        }
    }

    /// Install the forward edge; fails if one was already set.
    fn set_edge(&self, edge: Edge<In>, _: private::Token) -> Result<(), PipelineError> {
        self.edge
            .set(edge)
            .map_err(|_| PipelineError::EdgeAlreadySet)
    }
}
#[async_trait]
impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out> for Frontend<In, Out> {
    /// Route the response back to the `generate` call that registered this
    /// request id; on any routing failure, tell the producer to stop.
    async fn on_data(&self, data: Out, _: private::Token) -> Result<(), Error> {
        let ctx = data.context();
        // pull the oneshot sender registered by `generate`; lock scope is
        // limited to this single map access
        let entry = self.sinks.lock().unwrap().remove(ctx.id());
        let tx = match entry {
            Some(tx) => tx,
            None => {
                // nobody is waiting for this response
                ctx.stop_generating();
                return Err(PipelineError::DetatchedStreamReceiver.into());
            }
        };
        if tx.send(data).is_err() {
            // the waiting `generate` future was dropped
            ctx.stop_generating();
            return Err(PipelineError::DetatchedStreamReceiver.into());
        }
        Ok(())
    }
}
#[async_trait]
impl<In: PipelineIO, Out: PipelineIO> AsyncEngine<In, Out, Error> for Frontend<In, Out> {
    /// Register a oneshot for this request id, emit the request on the forward
    /// path, then await the response routed back by `on_data`.
    async fn generate(&self, request: In) -> Result<Out, Error> {
        let (tx, rx) = oneshot::channel::<Out>();
        // temporary lock guard drops at the end of the statement
        self.sinks
            .lock()
            .unwrap()
            .insert(request.id().to_string(), tx);
        self.on_next(request, private::Token {}).await?;
        match rx.await {
            Ok(out) => Ok(out),
            Err(_) => Err(PipelineError::DetatchedStreamSender.into()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{error::PipelineErrorExt, ManyOut, SingleIn};

    /// Both `generate` and `on_next` must fail with `NoEdge` on an unlinked frontend.
    #[tokio::test]
    async fn test_frontend_no_edge() {
        let source = Frontend::<SingleIn<()>, ManyOut<()>>::default();

        let error = source
            .generate(().into())
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();
        assert!(matches!(error, PipelineError::NoEdge));

        let error = source
            .on_next(().into(), private::Token)
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();
        assert!(matches!(error, PipelineError::NoEdge));
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use crate::engine::AsyncEngineContextProvider;
use super::*;
/// Generates a thin delegating wrapper around [`Frontend`] for the given type.
///
/// The wrapped type is expected to be a struct with an `inner: Frontend<In, Out>`
/// field; every trait impl produced here forwards directly to that inner
/// frontend, so the wrapper behaves identically while carrying its own name.
macro_rules! impl_frontend {
    ($type:ident) => {
        impl<In: PipelineIO, Out: PipelineIO> $type<In, Out> {
            // Construct the wrapper around a default (edge-less) Frontend.
            pub fn new() -> Arc<Self> {
                Arc::new(Self {
                    inner: Frontend::default(),
                })
            }
        }
        #[async_trait]
        impl<In: PipelineIO, Out: PipelineIO> Source<In> for $type<In, Out> {
            // Delegate: push data to the inner frontend's edge.
            async fn on_next(&self, data: In, token: private::Token) -> Result<(), Error> {
                self.inner.on_next(data, token).await
            }
            // Delegate: attach the downstream edge on the inner frontend.
            fn set_edge(&self, edge: Edge<In>, token: private::Token) -> Result<(), PipelineError> {
                self.inner.set_edge(edge, token)
            }
        }
        #[async_trait]
        impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out>
            for $type<In, Out>
        {
            // Delegate: route a response back through the inner frontend.
            async fn on_data(&self, data: Out, token: private::Token) -> Result<(), Error> {
                self.inner.on_data(data, token).await
            }
        }
        #[async_trait]
        impl<In: PipelineIO, Out: PipelineIO> AsyncEngine<In, Out, Error> for $type<In, Out> {
            // Delegate: full request/response round trip on the inner frontend.
            async fn generate(&self, request: In) -> Result<Out, Error> {
                self.inner.generate(request).await
            }
        }
    };
}
// Generate the delegating Frontend wrappers for the two public entry types.
impl_frontend!(ServiceFrontend);
impl_frontend!(SegmentSource);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{ManyOut, PipelineErrorExt, SingleIn};

    /// A frontend with no attached edge must fail `generate` with `NoEdge`.
    #[tokio::test]
    async fn test_pipeline_source_no_edge() {
        let source = Frontend::<SingleIn<()>, ManyOut<()>>::default();
        let stream = source
            .generate(().into())
            .await
            .unwrap_err()
            .try_into_pipeline_error()
            .unwrap();
        assert!(
            matches!(stream, PipelineError::NoEdge),
            "Expected NoEdge error"
        );
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
/// Registry struct that manages both shared and unique objects.
///
/// # Examples
///
/// ```
/// use triton_distributed::pipeline::registry::Registry;
///
/// let mut registry = Registry::new();
///
/// // Insert and retrieve shared objects
/// registry.insert_shared("shared1", 42);
/// assert_eq!(*registry.get_shared::<i32>("shared1").unwrap(), 42);
///
/// // Insert and take unique objects
/// registry.insert_unique("unique1", "Hello".to_string());
/// assert_eq!(registry.take_unique::<String>("unique1").unwrap(), "Hello");
///
/// // Taking the same unique again should fail since it was removed
/// assert!(registry.take_unique::<String>("unique1").is_err());
///
/// // Insert and clone unique objects
/// registry.insert_unique("unique2", "World".to_string());
/// assert_eq!(registry.clone_unique::<String>("unique2").unwrap(), "World");
///
/// // Taking a unique that was only cloned is still ok
/// assert!(registry.take_unique::<String>("unique2").is_ok());
/// ```
#[derive(Debug, Default)]
pub struct Registry {
    // Objects handed out behind `Arc`s; any number of holders at once.
    shared_storage: HashMap<String, Arc<dyn Any + Send + Sync>>,
    // Objects owned by the registry until moved out by `take_unique`.
    unique_storage: HashMap<String, Box<dyn Any + Send + Sync>>,
}
impl Registry {
    /// Create a new empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Check if a shared object exists in the registry by key.
    pub fn contains_shared(&self, key: &str) -> bool {
        self.shared_storage.contains_key(key)
    }

    /// Insert a shared object into the registry with a specific key.
    pub fn insert_shared<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
        let stored: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.shared_storage.insert(key.to_string(), stored);
    }

    /// Retrieve a shared object from the registry by key and type.
    ///
    /// Fails when the key is absent or the stored value is not of type `V`.
    pub fn get_shared<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
        let entry = self
            .shared_storage
            .get(key)
            .ok_or_else(|| format!("Shared key not found: {}", key))?;
        entry.clone().downcast::<V>().map_err(|_| {
            format!(
                "Failed to downcast to the requested type for shared key: {}",
                key
            )
        })
    }

    /// Check if a unique object exists in the registry by key.
    pub fn contains_unique(&self, key: &str) -> bool {
        self.unique_storage.contains_key(key)
    }

    /// Insert a unique object into the registry with a specific key.
    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
        let stored: Box<dyn Any + Send + Sync> = Box::new(value);
        self.unique_storage.insert(key.to_string(), stored);
    }

    /// Take a unique object from the registry by key and type, removing it
    /// from the registry.
    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
        let entry = self
            .unique_storage
            .remove(key)
            .ok_or_else(|| format!("Takable key not found: {}", key))?;
        entry.downcast::<V>().map(|value| *value).map_err(|_| {
            format!(
                "Failed to downcast to the requested type for unique key: {}",
                key
            )
        })
    }

    /// Clone a unique object from the registry if it implements `Clone`,
    /// leaving the stored value in place.
    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
        let entry = self
            .unique_storage
            .get(key)
            .ok_or_else(|| format!("Takable key not found: {}", key))?;
        entry.downcast_ref::<V>().cloned().ok_or_else(|| {
            format!(
                "Failed to downcast to the requested type for unique key: {}",
                key
            )
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get_shared() {
        let mut registry = Registry::new();
        registry.insert_shared("shared1", 42);
        let value = registry.get_shared::<i32>("shared1").unwrap();
        assert_eq!(*value, 42);
        // Requesting the wrong type must fail the downcast.
        assert!(registry.get_shared::<f64>("shared1").is_err());
    }

    #[test]
    fn test_insert_and_take_unique() {
        let mut registry = Registry::new();
        registry.insert_unique("unique1", "Hello".to_string());
        let taken = registry.take_unique::<String>("unique1").unwrap();
        assert_eq!(taken, "Hello");
        // The key was consumed by the take above.
        assert!(registry.take_unique::<String>("unique1").is_err());
    }

    #[test]
    fn test_insert_and_clone_then_take_unique() {
        let mut registry = Registry::new();
        registry.insert_unique("unique2", "World".to_string());
        let cloned = registry.clone_unique::<String>("unique2").unwrap();
        assert_eq!(cloned, "World");
        // Cloning leaves the stored value in place, so taking still succeeds.
        assert!(registry.take_unique::<String>("unique2").is_ok());
    }

    #[test]
    fn test_failed_take_after_cloning() {
        let mut registry = Registry::new();
        registry.insert_unique("unique3", "Another".to_string());
        assert_eq!(
            registry.clone_unique::<String>("unique3").unwrap(),
            "Another"
        );
        // Clone then take: the take consumes the stored value...
        assert_eq!(
            registry.take_unique::<String>("unique3").unwrap(),
            "Another"
        );
        // ...so a second take fails.
        assert!(registry.take_unique::<String>("unique3").is_err());
    }
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use serde::{Deserialize, Serialize};
pub mod annotated;
/// Identifies a component by `name` within a `namespace`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Component {
    // Component name.
    pub name: String,
    // Namespace the component belongs to.
    pub namespace: String,
}
/// Identifies an endpoint by `name`, addressed via its `component` and
/// `namespace`. Note `component` here is the component's name as a string,
/// not a [`Component`] value.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint {
    // Endpoint name.
    pub name: String,
    // Name of the owning component.
    pub component: String,
    // Namespace the component belongs to.
    pub namespace: String,
}
/// Push-routing strategy selector.
///
/// Serialized in `snake_case`: `"push_round_robin"` / `"push_random"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RouterType {
    // Round-robin selection across targets.
    PushRoundRobin,
    // Random selection across targets.
    PushRandom,
}
impl Default for RouterType {
    /// Random push routing is the default strategy.
    fn default() -> Self {
        Self::PushRandom
    }
}
/// Metadata describing a model: its `name`, the [`Component`] serving it,
/// and the [`RouterType`] used to route requests to it.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelMetaData {
    // Model name.
    pub name: String,
    // Component that serves this model.
    pub component: Component,
    // Strategy for routing requests to the component.
    pub router_type: RouterType,
}
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use super::*;
use crate::{error, Result};
/// Access to the optional list of annotations attached to a request/response.
pub trait AnnotationsProvider {
    /// The annotations attached to this item, if any.
    fn annotations(&self) -> Option<Vec<String>>;

    /// Returns `true` when `annotation` appears in [`Self::annotations`].
    fn has_annotation(&self, annotation: &str) -> bool {
        match self.annotations() {
            Some(list) => list.iter().any(|entry| entry == annotation),
            None => false,
        }
    }
}
/// Our services have the option of returning an "annotated" stream, which allows us
/// to include additional information with each delta. This is useful for debugging,
/// performance benchmarking, and improved observability.
///
/// All fields are optional and skipped during serialization when `None`,
/// so a plain data delta serializes to just `{"data": ...}`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Annotated<R> {
    // The payload of this delta, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<R>,
    // Optional identifier for this delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // Event name; the value "error" marks this delta as an error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<String>,
    // Free-form comment lines; carries error message(s) when event == "error".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub comment: Option<Vec<String>>,
}
impl<R> Annotated<R> {
    /// Create a new annotated item carrying an error: `event` is set to
    /// `"error"` and the message is stored in `comment`.
    pub fn from_error(error: String) -> Self {
        Self {
            data: None,
            id: None,
            event: Some("error".to_string()),
            comment: Some(vec![error]),
        }
    }

    /// Create a new annotated item carrying only data.
    pub fn from_data(data: R) -> Self {
        Self {
            data: Some(data),
            id: None,
            event: None,
            comment: None,
        }
    }

    /// Add an annotation to the stream.
    ///
    /// Annotations populate the `event` field; `comment` holds the
    /// JSON-serialized `value`.
    ///
    /// # Errors
    /// Returns a [`serde_json::Error`] if `value` fails to serialize.
    pub fn from_annotation<S: Serialize>(
        name: impl Into<String>,
        value: &S,
    ) -> Result<Self, serde_json::Error> {
        Ok(Self {
            data: None,
            id: None,
            event: Some(name.into()),
            comment: Some(vec![serde_json::to_string(value)?]),
        })
    }

    /// Convert to a [`Result<Self, String>`]
    /// If [`Self::event`] is "error", return an error message(s) held by
    /// [`Self::comment`], joined with ", ".
    pub fn ok(self) -> Result<Self, String> {
        // Reuse is_error() rather than re-implementing the check inline.
        if self.is_error() {
            // unwrap_or_else: avoid allocating the fallback vec when
            // comments are present (clippy::or_fun_call).
            return Err(self
                .comment
                .unwrap_or_else(|| vec!["unknown error".to_string()])
                .join(", "));
        }
        Ok(self)
    }

    /// True unless the event field marks this item as an error.
    pub fn is_ok(&self) -> bool {
        self.event.as_deref() != Some("error")
    }

    /// Inverse of [`Self::is_ok`].
    pub fn is_err(&self) -> bool {
        !self.is_ok()
    }

    /// True when any event (error or otherwise) is attached.
    pub fn is_event(&self) -> bool {
        self.event.is_some()
    }

    /// Move the annotation metadata (`id`, `event`, `comment`) onto a new
    /// data payload, discarding the current data.
    pub fn transfer<U: Serialize>(self, data: Option<U>) -> Annotated<U> {
        Annotated::<U> {
            data,
            id: self.id,
            event: self.event,
            comment: self.comment,
        }
    }

    /// Apply a mapping/transformation to the data field
    /// If the mapping fails, the error is returned as an annotated stream
    pub fn map_data<U, F>(self, transform: F) -> Annotated<U>
    where
        F: FnOnce(R) -> Result<U, String>,
    {
        // transpose() turns Option<Result<..>> into Result<Option<..>> so a
        // failed transform short-circuits into an error item.
        match self.data.map(transform).transpose() {
            Ok(data) => Annotated::<U> {
                data,
                id: self.id,
                event: self.event,
                comment: self.comment,
            },
            Err(e) => Annotated::from_error(e),
        }
    }

    /// True when the event field marks this item as an error.
    pub fn is_error(&self) -> bool {
        self.event.as_deref() == Some("error")
    }

    /// Convert into `Ok(Some(data))` when data is present, `Err` when the
    /// event is "error", and `Ok(None)` otherwise.
    pub fn into_result(self) -> Result<Option<R>> {
        match self.data {
            Some(data) => Ok(Some(data)),
            None => match self.event {
                // unwrap_or_else: lazy fallback allocation (clippy::or_fun_call).
                Some(event) if event == "error" => Err(error!(self
                    .comment
                    .unwrap_or_else(|| vec!["unknown error".to_string()])
                    .join(", ")))?,
                _ => Ok(None),
            },
        }
    }
}
// impl<R> Annotated<R>
// where
// R: for<'de> Deserialize<'de> + Serialize,
// {
// pub fn convert_sse_stream(
// stream: DataStream<Result<Message, SseCodecError>>,
// ) -> DataStream<Annotated<R>> {
// let stream = stream.map(|message| match message {
// Ok(message) => {
// let delta = Annotated::<R>::try_from(message);
// match delta {
// Ok(delta) => delta,
// Err(e) => Annotated::from_error(e.to_string()),
// }
// }
// Err(e) => Annotated::from_error(e.to_string()),
// });
// Box::pin(stream)
// }
// }
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! The [Runtime] module is the interface for [crate::component::Component]
//! to access shared resources. These include the thread pool, memory allocators and other shared resources.
//!
//! The [Runtime] holds the primary [`CancellationToken`] which can be used to terminate all attached
//! [crate::component::Component] instances.
//!
//! We expect in the future to offer topologically aware thread and memory resources, but for now the
//! set of resources is limited to the thread pool and cancellation token.
//!
//! Notes: We will need to do an evaluation of what is fully public, what is pub(crate) and what is
//! private; however, for now we are exposing most objects as fully public while the API is maturing.
use super::{error, log, Result, Runtime, RuntimeType};
use crate::config::{self, RuntimeConfig};
use futures::Future;
use once_cell::sync::OnceCell;
use std::sync::{Arc, Mutex};
use tokio::{signal, task::JoinHandle};
pub use tokio_util::sync::CancellationToken;
impl Runtime {
    /// Build a [`Runtime`] around the given primary [`RuntimeType`].
    ///
    /// Assigns a fresh UUID worker id, creates the root [`CancellationToken`],
    /// and spins up a single-threaded secondary runtime for background tasks.
    fn new(runtime: RuntimeType) -> Result<Runtime> {
        // worker id
        let id = Arc::new(uuid::Uuid::new_v4().to_string());
        // create a cancellation token
        let cancellation_token = CancellationToken::new();
        // secondary runtime for background etcd/nats tasks
        let secondary = RuntimeConfig::single_threaded().create_runtime()?;
        Ok(Runtime {
            id,
            primary: runtime,
            secondary: Arc::new(secondary),
            cancellation_token,
        })
    }
    /// Create a [`Runtime`] that wraps an externally owned tokio runtime,
    /// addressed via its [`tokio::runtime::Handle`].
    pub fn from_handle(handle: tokio::runtime::Handle) -> Result<Runtime> {
        let runtime = RuntimeType::External(handle);
        Runtime::new(runtime)
    }
    /// Create a [`Runtime`] instance from the settings
    /// See [`config::RuntimeConfig::from_settings`]
    pub fn from_settings() -> Result<Runtime> {
        let config = config::RuntimeConfig::from_settings()?;
        let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?));
        Runtime::new(owned)
    }
    /// Create a [`Runtime`] with a single-threaded primary async tokio runtime
    pub fn single_threaded() -> Result<Runtime> {
        let config = config::RuntimeConfig::single_threaded();
        let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?));
        Runtime::new(owned)
    }
    /// Returns the unique identifier for the [`Runtime`]
    pub fn id(&self) -> &str {
        &self.id
    }
    /// Returns a [`tokio::runtime::Handle`] for the primary/application thread pool
    pub fn primary(&self) -> tokio::runtime::Handle {
        self.primary.handle()
    }
    /// Returns a [`tokio::runtime::Handle`] for the secondary/background thread pool
    pub fn secondary(&self) -> &Arc<tokio::runtime::Runtime> {
        &self.secondary
    }
    /// Access the primary [`CancellationToken`] for the [`Runtime`]
    pub fn primary_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }
    /// Creates a child [`CancellationToken`] tied to the life-cycle of the
    /// [`Runtime`]'s root token via [`CancellationToken::child_token`]:
    /// cancelling the root cancels the child, but not the other way around.
    pub fn child_token(&self) -> CancellationToken {
        self.cancellation_token.child_token()
    }
    /// Shuts down the [`Runtime`] instance by cancelling the root token,
    /// which also cancels every child token derived from it.
    pub fn shutdown(&self) {
        self.cancellation_token.cancel();
    }
}
impl RuntimeType {
    /// Get a [`tokio::runtime::Handle`] to the underlying runtime, whether
    /// it is externally owned or shared.
    pub fn handle(&self) -> tokio::runtime::Handle {
        match self {
            Self::External(handle) => handle.clone(),
            Self::Shared(runtime) => runtime.handle().clone(),
        }
    }
}
impl std::fmt::Debug for RuntimeType {
    /// Debug-print only the variant name; the wrapped runtime/handle is
    /// not itself Debug-printable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            RuntimeType::External(_) => "RuntimeType::External",
            RuntimeType::Shared(_) => "RuntimeType::Shared",
        };
        f.write_str(label)
    }
}
// NOTE(review): the following non-source lines were web-page artifacts
// accidentally captured into the file and have been commented out so the
// file remains valid Rust.