"docs/vscode:/vscode.git/clone" did not exist on "61889a14fb0ef80dfd688d7e8da3fd91943b43da"
Unverified Commit 06b0ebef authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: transport agnostic request plane for dynamo - natless (#4246)

parent 381c428c
......@@ -414,6 +414,30 @@ pub fn inject_current_trace_into_nats_headers(headers: &mut async_nats::HeaderMa
inject_otel_context_into_nats_headers(headers, None);
}
/// Copy the current distributed tracing context into a plain string map so it
/// can be forwarded over HTTP/TCP transports.
///
/// When `get_distributed_tracing_context()` yields nothing, `headers` is left
/// untouched. Otherwise `traceparent` is always written, and `tracestate`,
/// `x-request-id` and `x-dynamo-request-id` are written only when the context
/// carries a value for them. Existing entries under those keys are overwritten.
pub fn inject_trace_headers_into_map(headers: &mut std::collections::HashMap<String, String>) {
    let Some(ctx) = get_distributed_tracing_context() else {
        return;
    };

    // W3C traceparent header; built before the optional fields are moved out of `ctx`.
    headers.insert("traceparent".to_string(), ctx.create_traceparent());

    // Optional tracestate plus the custom request IDs, in the order they are emitted.
    let optional_headers = [
        ("tracestate", ctx.tracestate),
        ("x-request-id", ctx.x_request_id),
        ("x-dynamo-request-id", ctx.x_dynamo_request_id),
    ];
    for (name, maybe_value) in optional_headers {
        if let Some(value) = maybe_value {
            headers.insert(name.to_string(), value);
        }
    }
}
/// Create a client_request span linked to the parent trace context
pub fn make_client_request_span(
operation: &str,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TODO - we need to reconcile what is in this crate with distributed::transports
//! Network layer for distributed communication
//!
//! Provides request distribution across multiple transport protocols:
//! - HTTP/2 for standard deployments
//! - TCP with length-prefixed protocol for high-performance scenarios
//! - NATS for legacy/messaging-based deployments
pub mod codec;
pub mod egress;
pub mod ingress;
pub mod manager;
pub mod tcp;
use crate::SystemHealth;
......
......@@ -8,6 +8,7 @@
//! In this module, we define three primary codecs used to issue single, two-part, or multi-part messages
//! on a byte stream.
use bytes::Bytes;
use tokio_util::{
bytes::{Buf, BufMut, BytesMut},
codec::{Decoder, Encoder},
......@@ -17,50 +18,709 @@ mod two_part;
pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
// // Custom codec that reads a u64 length header and the message of that length
// #[derive(Default)]
// pub struct LengthPrefixedCodec;
// impl LengthPrefixedCodec {
// pub fn new() -> Self {
// LengthPrefixedCodec {}
// }
// }
// impl Decoder for LengthPrefixedCodec {
// type Item = Vec<u8>;
// type Error = tokio::io::Error;
// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// // Check if enough bytes are available to read the length (u64 = 8 bytes)
// if src.len() < 8 {
// return Ok(None); // Not enough data to read the length
// }
// // Read the u64 length header
// let len = src.get_u64() as usize;
// // Check if enough bytes are available to read the full message
// if src.len() < len {
// src.reserve(len - src.len()); // Reserve space for the remaining bytes
// return Ok(None);
// }
// // Read the actual message bytes of the specified length
// let data = src.split_to(len).to_vec();
// Ok(Some(data))
// }
// }
// impl Encoder<Vec<u8>> for LengthPrefixedCodec {
// type Error = tokio::io::Error;
// fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
// // Write the length of the message as a u64 header
// dst.put_u64(item.len() as u64);
// // Write the actual message bytes
// dst.put_slice(&item);
// Ok(())
// }
// }
/// TCP request plane protocol message with endpoint routing
///
/// Wire format:
/// - endpoint_path_len: u16 (big-endian)
/// - endpoint_path: UTF-8 string
/// - payload_len: u32 (big-endian)
/// - payload: bytes
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpRequestMessage {
/// Endpoint path written after the 2-byte length prefix; encoding fails if it
/// is longer than `u16::MAX` bytes.
pub endpoint_path: String,
/// Payload bytes written after the 4-byte length prefix; encoding fails if
/// they exceed `u32::MAX` bytes.
pub payload: Bytes,
}
impl TcpRequestMessage {
    /// Build a request for `endpoint_path` carrying `payload`.
    pub fn new(endpoint_path: String, payload: Bytes) -> Self {
        Self {
            endpoint_path,
            payload,
        }
    }

    /// Encode message to bytes
    ///
    /// Layout: `u16` path length, path bytes, `u32` payload length, payload
    /// bytes (both lengths big-endian).
    ///
    /// # Errors
    /// Returns `InvalidInput` when the endpoint path is longer than `u16::MAX`
    /// bytes or the payload is longer than `u32::MAX` bytes.
    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
        let endpoint_bytes = self.endpoint_path.as_bytes();

        // `try_from` performs the range check and the narrowing in one step,
        // so no unchecked `as` cast is needed below.
        let endpoint_len = u16::try_from(endpoint_bytes.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Endpoint path too long: {} bytes", endpoint_bytes.len()),
            )
        })?;
        let payload_len = u32::try_from(self.payload.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Payload too large: {} bytes", self.payload.len()),
            )
        })?;

        // Use BytesMut for efficient buffer building
        let mut buf =
            BytesMut::with_capacity(2 + endpoint_bytes.len() + 4 + self.payload.len());

        // Write endpoint path length (2 bytes)
        buf.put_u16(endpoint_len);
        // Write endpoint path
        buf.put_slice(endpoint_bytes);
        // Write payload length (4 bytes)
        buf.put_u32(payload_len);
        // Write payload
        buf.put_slice(&self.payload);

        Ok(buf.freeze())
    }

    /// Decode message from bytes (for backward compatibility, zero-copy when possible)
    ///
    /// The payload is returned as a slice of `bytes`; the endpoint path is
    /// copied so it can be validated as UTF-8. Bytes after the declared
    /// payload are ignored.
    ///
    /// # Errors
    /// Returns `UnexpectedEof` when `bytes` is shorter than the lengths it
    /// declares, and `InvalidData` when the endpoint path is not valid UTF-8.
    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
        let eof = |msg: &'static str| {
            std::io::Error::new(std::io::ErrorKind::UnexpectedEof, msg)
        };

        if bytes.len() < 2 {
            return Err(eof("Not enough bytes for endpoint path length"));
        }

        // Read endpoint path length (2 bytes)
        let endpoint_len = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
        let mut offset = 2;

        // All bounds checks compare the *remaining* length against the
        // requested length. `offset <= bytes.len()` always holds here, so the
        // subtraction cannot underflow, and unlike `offset + len` it cannot
        // overflow `usize` when the length comes from untrusted wire data.
        if bytes.len() - offset < endpoint_len {
            return Err(eof("Not enough bytes for endpoint path"));
        }

        // Read endpoint path (requires copy for UTF-8 validation)
        let endpoint_path = String::from_utf8(bytes[offset..offset + endpoint_len].to_vec())
            .map_err(|e| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Invalid UTF-8: {}", e),
                )
            })?;
        offset += endpoint_len;

        if bytes.len() - offset < 4 {
            return Err(eof("Not enough bytes for payload length"));
        }

        // Read payload length (4 bytes)
        let payload_len = u32::from_be_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]) as usize;
        offset += 4;

        let available = bytes.len() - offset;
        if available < payload_len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!(
                    "Not enough bytes for payload: expected {}, got {}",
                    payload_len, available
                ),
            ));
        }

        // Read payload (zero-copy slice); `offset + payload_len <= bytes.len()`
        // was established above.
        let payload = bytes.slice(offset..offset + payload_len);

        Ok(Self {
            endpoint_path,
            payload,
        })
    }
}
/// Codec for encoding/decoding TcpRequestMessage
/// Supports max_message_size enforcement
#[derive(Clone, Default)]
pub struct TcpRequestCodec {
// Upper bound on the full frame size (length prefixes + path + payload) in
// bytes; `None` disables the check. The derived `Default` yields `None`.
max_message_size: Option<usize>,
}
impl TcpRequestCodec {
/// Create a codec that rejects frames larger than `max_message_size` bytes,
/// or accepts any size when `None` is given.
pub fn new(max_message_size: Option<usize>) -> Self {
Self { max_message_size }
}
}
impl Decoder for TcpRequestCodec {
    type Item = TcpRequestMessage;
    type Error = std::io::Error;

    /// Try to pull one complete request frame out of `src`.
    ///
    /// Returns `Ok(None)` without consuming anything while the frame is still
    /// incomplete, and `Ok(Some(_))` once the path and payload are fully
    /// buffered.
    ///
    /// # Errors
    /// Returns `InvalidData` when the declared frame size exceeds
    /// `max_message_size`, when it cannot be represented in `usize`, or when
    /// the endpoint path is not valid UTF-8. In the UTF-8 case the length
    /// prefix and path bytes have already been consumed from `src`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Need at least 2 bytes for endpoint_path_len
        if src.len() < 2 {
            return Ok(None);
        }

        // Peek at endpoint path length without consuming
        let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize;
        let header_size = 2 + endpoint_len + 4; // path_len + path + payload_len
        if src.len() < header_size {
            return Ok(None);
        }

        // Peek at payload length
        let payload_len_offset = 2 + endpoint_len;
        let payload_len = u32::from_be_bytes([
            src[payload_len_offset],
            src[payload_len_offset + 1],
            src[payload_len_offset + 2],
            src[payload_len_offset + 3],
        ]) as usize;

        // The payload length comes off the wire; on targets where `usize` is
        // 32 bits, `header_size + payload_len` could wrap. Treat that as an
        // oversized frame instead of panicking or mis-framing the stream.
        let total_len = header_size.checked_add(payload_len).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "Request too large: {} header bytes + {} payload bytes",
                    header_size, payload_len
                ),
            )
        })?;

        // Check max message size
        if let Some(max_size) = self.max_message_size {
            if total_len > max_size {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!(
                        "Request too large: {} bytes (max: {} bytes)",
                        total_len, max_size
                    ),
                ));
            }
        }

        // Check if we have the full message
        if src.len() < total_len {
            return Ok(None);
        }

        // We have a complete message, advance past length prefix
        src.advance(2);

        // Read endpoint path
        let endpoint_bytes = src.split_to(endpoint_len);
        let endpoint_path = String::from_utf8(endpoint_bytes.to_vec()).map_err(|e| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Invalid UTF-8 in endpoint path: {}", e),
            )
        })?;

        // Advance past payload length
        src.advance(4);

        // Read payload
        let payload = src.split_to(payload_len).freeze();

        Ok(Some(TcpRequestMessage {
            endpoint_path,
            payload,
        }))
    }
}
impl Encoder<TcpRequestMessage> for TcpRequestCodec {
type Error = std::io::Error;
fn encode(&mut self, item: TcpRequestMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
let endpoint_bytes = item.endpoint_path.as_bytes();
let endpoint_len = endpoint_bytes.len();
if endpoint_len > u16::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Endpoint path too long: {} bytes", endpoint_len),
));
}
if item.payload.len() > u32::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Payload too large: {} bytes", item.payload.len()),
));
}
let total_len = 2 + endpoint_len + 4 + item.payload.len();
// Check max message size
if let Some(max_size) = self.max_message_size
&& total_len > max_size
{
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!(
"Request too large: {} bytes (max: {} bytes)",
total_len, max_size
),
));
}
// Reserve space
dst.reserve(total_len);
// Write endpoint path length
dst.put_u16(endpoint_len as u16);
// Write endpoint path
dst.put_slice(endpoint_bytes);
// Write payload length
dst.put_u32(item.payload.len() as u32);
// Write payload
dst.put_slice(&item.payload);
Ok(())
}
}
/// TCP response message (acknowledgment or error)
///
/// Wire format:
/// - length: u32 (big-endian)
/// - data: bytes
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpResponseMessage {
/// Response body written after the 4-byte length prefix; encoding fails if it
/// exceeds `u32::MAX` bytes.
pub data: Bytes,
}
impl TcpResponseMessage {
    /// Wrap `data` as a response body.
    pub fn new(data: Bytes) -> Self {
        Self { data }
    }

    /// Build a response with a zero-length body.
    pub fn empty() -> Self {
        Self { data: Bytes::new() }
    }

    /// Encode response to bytes (for backward compatibility)
    ///
    /// Layout: `u32` big-endian body length followed by the body.
    ///
    /// # Errors
    /// Returns `InvalidInput` when the body is longer than `u32::MAX` bytes.
    pub fn encode(&self) -> Result<Bytes, std::io::Error> {
        // `try_from` performs the range check and the narrowing in one step.
        let data_len = u32::try_from(self.data.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Response too large: {} bytes", self.data.len()),
            )
        })?;

        // Use BytesMut for efficient buffer building
        let mut buf = BytesMut::with_capacity(4 + self.data.len());

        // Write length (4 bytes)
        buf.put_u32(data_len);
        // Write data
        buf.put_slice(&self.data);

        Ok(buf.freeze())
    }

    /// Decode response from bytes (for backward compatibility, zero-copy when possible)
    ///
    /// The body is returned as a slice of `bytes`. Bytes after the declared
    /// body are ignored.
    ///
    /// # Errors
    /// Returns `UnexpectedEof` when `bytes` is shorter than the 4-byte prefix
    /// or than the length that prefix declares.
    pub fn decode(bytes: &Bytes) -> Result<Self, std::io::Error> {
        if bytes.len() < 4 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "Not enough bytes for response length",
            ));
        }

        // Read length (4 bytes)
        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;

        // Compare against the remaining byte count rather than computing
        // `4 + len`, which could overflow `usize` on 32-bit targets for a
        // wire-supplied length.
        let available = bytes.len() - 4;
        if available < len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!(
                    "Not enough bytes for response: expected {}, got {}",
                    len, available
                ),
            ));
        }

        // Read data (zero-copy slice); `4 + len <= bytes.len()` holds here.
        let data = bytes.slice(4..4 + len);

        Ok(Self { data })
    }
}
/// Codec for encoding/decoding TcpResponseMessage
/// Supports max_message_size enforcement
#[derive(Clone, Default)]
pub struct TcpResponseCodec {
// Upper bound on the full frame size (4-byte length prefix + data) in bytes;
// `None` disables the check. The derived `Default` yields `None`.
max_message_size: Option<usize>,
}
impl TcpResponseCodec {
/// Create a codec that rejects frames larger than `max_message_size` bytes,
/// or accepts any size when `None` is given.
pub fn new(max_message_size: Option<usize>) -> Self {
Self { max_message_size }
}
}
impl Decoder for TcpResponseCodec {
    type Item = TcpResponseMessage;
    type Error = std::io::Error;

    /// Try to pull one complete response frame out of `src`.
    ///
    /// Returns `Ok(None)` without consuming anything while the frame is still
    /// incomplete, and `Ok(Some(_))` once the whole body is buffered.
    ///
    /// # Errors
    /// Returns `InvalidData` when the declared frame size exceeds
    /// `max_message_size` or cannot be represented in `usize`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Need at least 4 bytes for length
        if src.len() < 4 {
            return Ok(None);
        }

        // Peek at message length without consuming
        let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;

        // The length comes off the wire; on targets where `usize` is 32 bits,
        // `4 + data_len` could wrap. Report that as an oversized frame.
        let total_len = data_len.checked_add(4).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Response too large: {} data bytes", data_len),
            )
        })?;

        // Check max message size
        if let Some(max_size) = self.max_message_size {
            if total_len > max_size {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!(
                        "Response too large: {} bytes (max: {} bytes)",
                        total_len, max_size
                    ),
                ));
            }
        }

        // Check if we have the full message
        if src.len() < total_len {
            return Ok(None);
        }

        // Advance past the length prefix
        src.advance(4);

        // Read data
        let data = src.split_to(data_len).freeze();

        Ok(Some(TcpResponseMessage { data }))
    }
}
impl Encoder<TcpResponseMessage> for TcpResponseCodec {
type Error = std::io::Error;
fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
if item.data.len() > u32::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Response too large: {} bytes", item.data.len()),
));
}
let total_len = 4 + item.data.len();
// Check max message size
if let Some(max_size) = self.max_message_size
&& total_len > max_size
{
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!(
"Response too large: {} bytes (max: {} bytes)",
total_len, max_size
),
));
}
// Reserve space
dst.reserve(total_len);
// Write length
dst.put_u32(item.data.len() as u32);
// Write data
dst.put_slice(&item.data);
Ok(())
}
}
// Unit tests for the TCP request/response messages and their codecs.
// The first group exercises the standalone `encode`/`decode` helpers, the
// second group drives the `Decoder`/`Encoder` impls directly, and the last two
// tests run the request codec through `FramedWrite`/`FramedRead`.
#[cfg(test)]
mod tests {
use super::*;
// Round trip of a request with a non-empty path and payload.
#[test]
fn test_tcp_request_encode_decode() {
let msg = TcpRequestMessage::new(
"test.endpoint".to_string(),
Bytes::from(vec![1, 2, 3, 4, 5]),
);
let encoded = msg.encode().unwrap();
let decoded = TcpRequestMessage::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
}
// Round trip of a request whose payload is zero bytes long.
#[test]
fn test_tcp_request_empty_payload() {
let msg = TcpRequestMessage::new("test".to_string(), Bytes::new());
let encoded = msg.encode().unwrap();
let decoded = TcpRequestMessage::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
}
// Round trip of a request with a 1 MiB payload.
#[test]
fn test_tcp_request_large_payload() {
let payload = Bytes::from(vec![42u8; 1024 * 1024]); // 1MB
let msg = TcpRequestMessage::new("large".to_string(), payload);
let encoded = msg.encode().unwrap();
let decoded = TcpRequestMessage::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
}
// Dropping the last two payload bytes must make `decode` fail.
#[test]
fn test_tcp_request_decode_truncated() {
let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
let encoded = msg.encode().unwrap();
// Truncate the encoded message
let truncated = encoded.slice(..encoded.len() - 2);
let result = TcpRequestMessage::decode(&truncated);
assert!(result.is_err());
}
// Round trip of a response with a non-empty body.
#[test]
fn test_tcp_response_encode_decode() {
let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
let encoded = msg.encode().unwrap();
let decoded = TcpResponseMessage::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
}
// Round trip of the empty response.
#[test]
fn test_tcp_response_empty() {
let msg = TcpResponseMessage::empty();
let encoded = msg.encode().unwrap();
let decoded = TcpResponseMessage::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
assert_eq!(decoded.data.len(), 0);
}
// Three bytes are fewer than the 4-byte length prefix, so `decode` must fail.
#[test]
fn test_tcp_response_decode_truncated() {
let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
let encoded = msg.encode().unwrap();
// Truncate the encoded message
let truncated = encoded.slice(..3);
let result = TcpResponseMessage::decode(&truncated);
assert!(result.is_err());
}
// Multi-byte UTF-8 in the endpoint path must survive a round trip.
#[test]
fn test_tcp_request_unicode_endpoint() {
let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3]));
let encoded = msg.encode().unwrap();
let decoded = TcpRequestMessage::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
}
// Encoder output fed straight back into the Decoder yields the same message.
#[test]
fn test_tcp_request_codec() {
use tokio_util::codec::{Decoder, Encoder};
let msg = TcpRequestMessage::new(
"test.endpoint".to_string(),
Bytes::from(vec![1, 2, 3, 4, 5]),
);
let mut codec = TcpRequestCodec::new(None);
let mut buf = BytesMut::new();
// Encode
codec.encode(msg.clone(), &mut buf).unwrap();
// Decode
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, msg);
}
// The first 5 bytes hold the 2-byte path length plus only part of the
// 13-byte path, so the decoder must report "need more data".
#[test]
fn test_tcp_request_codec_partial() {
use tokio_util::codec::Decoder;
let msg = TcpRequestMessage::new(
"test.endpoint".to_string(),
Bytes::from(vec![1, 2, 3, 4, 5]),
);
let encoded = msg.encode().unwrap();
let mut codec = TcpRequestCodec::new(None);
// Feed partial data
let mut buf = BytesMut::from(&encoded[..5]);
assert!(codec.decode(&mut buf).unwrap().is_none());
// Feed rest of data
buf.extend_from_slice(&encoded[5..]);
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, msg);
}
// The frame is 2 + 4 + 4 + 5 = 15 bytes, above the 10-byte limit.
#[test]
fn test_tcp_request_codec_max_size() {
use tokio_util::codec::Encoder;
let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]));
let mut codec = TcpRequestCodec::new(Some(10)); // Too small
let mut buf = BytesMut::new();
let result = codec.encode(msg, &mut buf);
assert!(result.is_err());
}
// Encoder output fed straight back into the Decoder yields the same message.
#[test]
fn test_tcp_response_codec() {
use tokio_util::codec::{Decoder, Encoder};
let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
let mut codec = TcpResponseCodec::new(None);
let mut buf = BytesMut::new();
// Encode
codec.encode(msg.clone(), &mut buf).unwrap();
// Decode
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, msg);
}
// Three bytes are fewer than the 4-byte length prefix, so the decoder must
// report "need more data" until the rest arrives.
#[test]
fn test_tcp_response_codec_partial() {
use tokio_util::codec::Decoder;
let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
let encoded = msg.encode().unwrap();
let mut codec = TcpResponseCodec::new(None);
// Feed partial data
let mut buf = BytesMut::from(&encoded[..3]);
assert!(codec.decode(&mut buf).unwrap().is_none());
// Feed rest of data
buf.extend_from_slice(&encoded[3..]);
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded, msg);
}
// The frame is 4 + 5 = 9 bytes, above the 5-byte limit.
#[test]
fn test_tcp_response_codec_max_size() {
use tokio_util::codec::Encoder;
let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5]));
let mut codec = TcpResponseCodec::new(Some(5)); // Too small
let mut buf = BytesMut::new();
let result = codec.encode(msg, &mut buf);
assert!(result.is_err());
}
/// Demonstrates how framed codec enables testability without actual TCP connections
#[tokio::test]
async fn test_framed_codec_integration() {
use futures::{SinkExt, StreamExt};
use std::io::Cursor;
use tokio_util::codec::{FramedRead, FramedWrite};
// Simulate a duplex connection using in-memory buffer
let mut buffer = Vec::new();
// Writer side: encode requests
{
let cursor = Cursor::new(&mut buffer);
let mut writer = FramedWrite::new(cursor, TcpRequestCodec::new(None));
let msg1 = TcpRequestMessage::new("endpoint1".to_string(), Bytes::from("data1"));
let msg2 = TcpRequestMessage::new("endpoint2".to_string(), Bytes::from("data2"));
writer.send(msg1).await.unwrap();
writer.send(msg2).await.unwrap();
}
// Reader side: decode requests
{
let cursor = Cursor::new(&buffer[..]);
let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));
// Both frames were written back to back; they must come out in order.
let decoded1 = reader.next().await.unwrap().unwrap();
assert_eq!(decoded1.endpoint_path, "endpoint1");
assert_eq!(decoded1.payload, Bytes::from("data1"));
let decoded2 = reader.next().await.unwrap().unwrap();
assert_eq!(decoded2.endpoint_path, "endpoint2");
assert_eq!(decoded2.payload, Bytes::from("data2"));
}
}
/// Demonstrates testing partial message handling
#[tokio::test]
async fn test_framed_codec_partial_messages() {
use futures::StreamExt;
use std::io::Cursor;
use tokio_util::codec::FramedRead;
// Create a message and encode it
let msg = TcpRequestMessage::new("test".to_string(), Bytes::from("hello"));
let encoded = msg.encode().unwrap();
// Split the encoded message into chunks
let chunk1 = &encoded[..5];
let chunk2 = &encoded[5..];
// Create a buffer that simulates receiving data in chunks
let mut full_buffer = Vec::new();
full_buffer.extend_from_slice(chunk1);
// Reader can't decode yet (partial data)
// NOTE(review): `_reader` is built but never polled, so this scope asserts
// nothing about the partial-data path; only the full-buffer read below is
// actually checked.
{
let cursor = Cursor::new(&full_buffer[..]);
let _reader = FramedRead::new(cursor, TcpRequestCodec::new(None));
// In real async, this would return Ok(None) and wait for more data
// For Cursor, it returns None at EOF
}
// Add the rest of the data
full_buffer.extend_from_slice(chunk2);
// Now decoding succeeds
{
let cursor = Cursor::new(&full_buffer[..]);
let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None));
let decoded = reader.next().await.unwrap().unwrap();
assert_eq!(decoded.endpoint_path, "test");
assert_eq!(decoded.payload, Bytes::from("hello"));
}
}
}
......@@ -2,6 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
pub mod addressed_router;
pub mod http_router;
pub mod nats_client;
pub mod push_router;
// Unified request plane interface and implementations
pub mod tcp_client;
pub mod unified_client;
use super::*;
......@@ -3,14 +3,13 @@
use std::sync::Arc;
use super::unified_client::RequestPlaneClient;
use super::*;
use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
use crate::logging::DistributedTraceContext;
use crate::logging::get_distributed_tracing_context;
use crate::logging::inject_otel_context_into_nats_headers;
use crate::logging::inject_trace_headers_into_map;
use crate::pipeline::network::ConnectionInfo;
use crate::pipeline::network::NetworkStreamWrapper;
use crate::pipeline::network::PendingConnections;
use crate::pipeline::network::ResponseService;
use crate::pipeline::network::STREAM_ERR_MSG;
use crate::pipeline::network::StreamOptions;
use crate::pipeline::network::TwoPartCodec;
......@@ -20,8 +19,6 @@ use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
use crate::protocols::maybe_error::MaybeError;
use anyhow::{Error, Result};
use async_nats::client::Client;
use async_nats::{HeaderMap, HeaderValue};
use serde::Deserialize;
use serde::Serialize;
use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
......@@ -59,26 +56,30 @@ impl<T> AddressedRequest<T> {
Self { request, address }
}
fn into_parts(self) -> (T, String) {
pub(crate) fn into_parts(self) -> (T, String) {
(self.request, self.address)
}
}
pub struct AddressedPushRouter {
// todo: generalize with a generic
req_transport: Client,
// Request transport (unified trait object - works with all transports)
req_client: Arc<dyn RequestPlaneClient>,
// todo: generalize with a generic
// Response transport (TCP streaming - unchanged)
resp_transport: Arc<tcp::server::TcpStreamServer>,
}
impl AddressedPushRouter {
/// Create a new router with a request plane client
///
/// This is the unified constructor that works with any transport type.
/// The client is provided as a trait object, hiding the specific implementation.
pub fn new(
req_transport: Client,
req_client: Arc<dyn RequestPlaneClient>,
resp_transport: Arc<tcp::server::TcpStreamServer>,
) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
req_transport,
req_client,
resp_transport,
}))
}
......@@ -154,32 +155,22 @@ where
// TRANSPORT ABSTRACT REQUIRED - END HERE
tracing::trace!(request_id, "enqueueing two-part message to nats");
// Insert Trace Context into Headers
// Enables span to be created in push_endpoint before
// payload is parsed
// Prepare trace headers using the OpenTelemetry injector pattern
// This handles traceparent and tracestate headers according to W3C Trace Context standard
let mut headers = HeaderMap::new();
inject_otel_context_into_nats_headers(&mut headers, None);
// Send request using unified client interface
tracing::trace!(
request_id,
transport = self.req_client.transport_name(),
address = %address,
"Sending request via request plane client"
);
// Add additional custom headers that aren't handled by the OpenTelemetry propagator
if let Some(trace_context) = get_distributed_tracing_context() {
if let Some(x_request_id) = trace_context.x_request_id {
headers.insert("x-request-id", x_request_id);
}
if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
headers.insert("x-dynamo-request-id", x_dynamo_request_id);
}
}
// Prepare trace headers using shared helper
let mut headers = std::collections::HashMap::new();
inject_trace_headers_into_map(&mut headers);
// we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats
// will handle this for us
// Send request (works for all transport types)
let _response = self
.req_transport
.request_with_headers(address.to_string(), headers, buffer)
.req_client
.send_request(address, buffer, headers)
.await?;
tracing::trace!(request_id, "awaiting transport handshake");
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! HTTP/2 client for request plane
use super::unified_client::{Headers, RequestPlaneClient};
use crate::Result;
use async_trait::async_trait;
use bytes::Bytes;
use std::sync::Arc;
use std::time::Duration;
/// Default timeout for HTTP requests (ack only, not full response)
const DEFAULT_HTTP_REQUEST_TIMEOUT_SECS: u64 = 5;
// HTTP/2 Performance Configuration Constants
// These are the values `Http2Config::default()` starts from.
/// Default HTTP/2 frame size, in bytes.
const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 1024; // 1MB frame size for better throughput
/// Default cap on concurrent HTTP/2 streams.
const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 1000; // Allow more concurrent streams
/// Default number of idle pooled connections kept per host.
const DEFAULT_POOL_MAX_IDLE_PER_HOST: usize = 100; // Increased connection pool
/// Default idle timeout for pooled connections, in seconds.
const DEFAULT_POOL_IDLE_TIMEOUT_SECS: u64 = 90; // Keep connections alive longer
/// Default interval between HTTP/2 keep-alive pings, in seconds.
const DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL_SECS: u64 = 30; // Send pings every 30s
/// Default time to wait for a keep-alive ping response, in seconds.
const DEFAULT_HTTP2_KEEP_ALIVE_TIMEOUT_SECS: u64 = 10; // Timeout for ping responses
/// Default for the adaptive flow-control window switch.
const DEFAULT_HTTP2_ADAPTIVE_WINDOW: bool = true; // Enable adaptive flow control
/// HTTP/2 Performance Configuration
///
/// Note: of these fields, `HttpRequestClient::with_config` currently passes
/// only `pool_max_idle_per_host`, `pool_idle_timeout` and `request_timeout`
/// to the `reqwest` builder; the remaining fields are stored on the client
/// but not applied there.
#[derive(Debug, Clone)]
pub struct Http2Config {
/// HTTP/2 frame size in bytes (env: `DYN_HTTP2_MAX_FRAME_SIZE`).
pub max_frame_size: u32,
/// Cap on concurrent HTTP/2 streams (env: `DYN_HTTP2_MAX_CONCURRENT_STREAMS`).
pub max_concurrent_streams: u32,
/// Idle pooled connections kept per host (env: `DYN_HTTP2_POOL_MAX_IDLE_PER_HOST`).
pub pool_max_idle_per_host: usize,
/// Idle timeout for pooled connections (env: `DYN_HTTP2_POOL_IDLE_TIMEOUT_SECS`, seconds).
pub pool_idle_timeout: Duration,
/// Interval between keep-alive pings (env: `DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS`, seconds).
pub keep_alive_interval: Duration,
/// Time to wait for a ping response (env: `DYN_HTTP2_KEEP_ALIVE_TIMEOUT_SECS`, seconds).
pub keep_alive_timeout: Duration,
/// Adaptive flow-control window switch (env: `DYN_HTTP2_ADAPTIVE_WINDOW`).
pub adaptive_window: bool,
/// Overall per-request timeout (env: `DYN_HTTP_REQUEST_TIMEOUT`, seconds).
pub request_timeout: Duration,
}
impl Default for Http2Config {
fn default() -> Self {
Self {
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
max_concurrent_streams: DEFAULT_MAX_CONCURRENT_STREAMS,
pool_max_idle_per_host: DEFAULT_POOL_MAX_IDLE_PER_HOST,
pool_idle_timeout: Duration::from_secs(DEFAULT_POOL_IDLE_TIMEOUT_SECS),
keep_alive_interval: Duration::from_secs(DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL_SECS),
keep_alive_timeout: Duration::from_secs(DEFAULT_HTTP2_KEEP_ALIVE_TIMEOUT_SECS),
adaptive_window: DEFAULT_HTTP2_ADAPTIVE_WINDOW,
request_timeout: Duration::from_secs(DEFAULT_HTTP_REQUEST_TIMEOUT_SECS),
}
}
}
impl Http2Config {
    /// Create configuration from environment variables
    ///
    /// Starts from [`Http2Config::default`] and overrides each field whose
    /// `DYN_HTTP2_*` / `DYN_HTTP_REQUEST_TIMEOUT` variable is set and parses
    /// successfully. Unset or unparsable variables leave the default in
    /// place. Duration variables are read as whole seconds.
    pub fn from_env() -> Self {
        /// Read `name` and parse it as `T`; `None` if unset or unparsable.
        fn env_parsed<T: std::str::FromStr>(name: &str) -> Option<T> {
            std::env::var(name).ok().and_then(|raw| raw.parse().ok())
        }

        let defaults = Self::default();

        Self {
            max_frame_size: env_parsed("DYN_HTTP2_MAX_FRAME_SIZE")
                .unwrap_or(defaults.max_frame_size),
            max_concurrent_streams: env_parsed("DYN_HTTP2_MAX_CONCURRENT_STREAMS")
                .unwrap_or(defaults.max_concurrent_streams),
            pool_max_idle_per_host: env_parsed("DYN_HTTP2_POOL_MAX_IDLE_PER_HOST")
                .unwrap_or(defaults.pool_max_idle_per_host),
            pool_idle_timeout: env_parsed::<u64>("DYN_HTTP2_POOL_IDLE_TIMEOUT_SECS")
                .map(Duration::from_secs)
                .unwrap_or(defaults.pool_idle_timeout),
            keep_alive_interval: env_parsed::<u64>("DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS")
                .map(Duration::from_secs)
                .unwrap_or(defaults.keep_alive_interval),
            keep_alive_timeout: env_parsed::<u64>("DYN_HTTP2_KEEP_ALIVE_TIMEOUT_SECS")
                .map(Duration::from_secs)
                .unwrap_or(defaults.keep_alive_timeout),
            // A set-but-invalid value falls back to the compile-time default.
            adaptive_window: match std::env::var("DYN_HTTP2_ADAPTIVE_WINDOW") {
                Ok(raw) => raw.parse().unwrap_or(DEFAULT_HTTP2_ADAPTIVE_WINDOW),
                Err(_) => defaults.adaptive_window,
            },
            request_timeout: env_parsed::<u64>("DYN_HTTP_REQUEST_TIMEOUT")
                .map(Duration::from_secs)
                .unwrap_or(defaults.request_timeout),
        }
    }
}
/// HTTP/2 request plane client
///
/// Wraps a `reqwest::Client` together with the configuration it was built from.
pub struct HttpRequestClient {
// Underlying HTTP client used to send requests.
client: reqwest::Client,
// Configuration passed to `with_config`; exposed through `config()`.
config: Http2Config,
}
impl HttpRequestClient {
    /// Create a client using [`Http2Config::default`].
    pub fn new() -> Result<Self> {
        Self::with_config(Http2Config::default())
    }

    /// Create a client with a custom request timeout (legacy method).
    ///
    /// Every other setting comes from [`Http2Config::default`].
    pub fn with_timeout(timeout: Duration) -> Result<Self> {
        Self::with_config(Http2Config {
            request_timeout: timeout,
            ..Default::default()
        })
    }

    /// Create a client from an explicit configuration.
    ///
    /// Only the connection-pool settings and the request timeout are passed
    /// to the `reqwest` builder; the remaining `Http2Config` fields are kept
    /// on the returned client but not applied here.
    ///
    /// # Errors
    /// Propagates the error from `reqwest::ClientBuilder::build`.
    pub fn with_config(config: Http2Config) -> Result<Self> {
        let client = reqwest::Client::builder()
            .timeout(config.request_timeout)
            .pool_idle_timeout(config.pool_idle_timeout)
            .pool_max_idle_per_host(config.pool_max_idle_per_host)
            .build()?;

        Ok(Self { client, config })
    }

    /// Create a client configured via [`Http2Config::from_env`].
    pub fn from_env() -> Result<Self> {
        Self::with_config(Http2Config::from_env())
    }

    /// Borrow the configuration this client was built with.
    pub fn config(&self) -> &Http2Config {
        &self.config
    }
}
impl Default for HttpRequestClient {
/// Build a client via [`HttpRequestClient::new`].
///
/// # Panics
/// Panics if `new` returns an error.
fn default() -> Self {
Self::new().expect("Failed to create HTTP request client")
}
}
#[async_trait]
impl RequestPlaneClient for HttpRequestClient {
    /// POST `payload` to `address` as `application/octet-stream`, attaching
    /// every entry of `headers`, and return the response body.
    ///
    /// # Errors
    /// Fails when the request cannot be sent, when the response status is not
    /// a success status (the error text includes the status and whatever body
    /// could be read), or when the body cannot be read.
    async fn send_request(
        &self,
        address: String,
        payload: Bytes,
        headers: Headers,
    ) -> Result<Bytes> {
        let base = self
            .client
            .post(&address)
            .header("Content-Type", "application/octet-stream")
            .body(payload);

        // Add custom headers
        let request = headers
            .into_iter()
            .fold(base, |builder, (name, value)| builder.header(name, value));

        let response = request.send().await?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty.
            let detail = response.text().await.unwrap_or_default();
            anyhow::bail!("HTTP request failed with status {}: {}", status, detail);
        }

        Ok(response.bytes().await?)
    }

    /// Static label identifying this transport.
    fn transport_name(&self) -> &'static str {
        "http2"
    }

    /// Always reports healthy.
    fn is_healthy(&self) -> bool {
        // HTTP client is stateless and always healthy if created successfully
        true
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{Router, body::Bytes as AxumBytes, extract::State as AxumState, routing::post};
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
#[test]
fn test_http_client_creation() {
let client = HttpRequestClient::new();
assert!(client.is_ok());
}
#[test]
fn test_http_client_with_custom_timeout() {
let client = HttpRequestClient::with_timeout(Duration::from_secs(10));
assert!(client.is_ok());
assert_eq!(
client.unwrap().config.request_timeout,
Duration::from_secs(10)
);
}
#[test]
fn test_http2_config_from_env() {
// Set environment variables
unsafe {
std::env::set_var("DYN_HTTP2_MAX_FRAME_SIZE", "2097152"); // 2MB
std::env::set_var("DYN_HTTP2_MAX_CONCURRENT_STREAMS", "2000");
std::env::set_var("DYN_HTTP2_POOL_MAX_IDLE_PER_HOST", "200");
std::env::set_var("DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS", "60");
std::env::set_var("DYN_HTTP2_ADAPTIVE_WINDOW", "false");
}
let config = Http2Config::from_env();
assert_eq!(config.max_frame_size, 2097152);
assert_eq!(config.max_concurrent_streams, 2000);
assert_eq!(config.pool_max_idle_per_host, 200);
assert_eq!(config.keep_alive_interval, Duration::from_secs(60));
assert!(!config.adaptive_window);
// Clean up
unsafe {
std::env::remove_var("DYN_HTTP2_MAX_FRAME_SIZE");
std::env::remove_var("DYN_HTTP2_MAX_CONCURRENT_STREAMS");
std::env::remove_var("DYN_HTTP2_POOL_MAX_IDLE_PER_HOST");
std::env::remove_var("DYN_HTTP2_KEEP_ALIVE_INTERVAL_SECS");
std::env::remove_var("DYN_HTTP2_ADAPTIVE_WINDOW");
}
}
#[test]
fn test_http_client_with_custom_config() {
let config = Http2Config {
max_frame_size: 512 * 1024, // 512KB
max_concurrent_streams: 500,
pool_max_idle_per_host: 75,
pool_idle_timeout: Duration::from_secs(60),
keep_alive_interval: Duration::from_secs(45),
keep_alive_timeout: Duration::from_secs(15),
adaptive_window: false,
request_timeout: Duration::from_secs(8),
};
let client = HttpRequestClient::with_config(config.clone());
assert!(client.is_ok());
let client = client.unwrap();
assert_eq!(client.config.max_frame_size, 512 * 1024);
assert_eq!(client.config.max_concurrent_streams, 500);
assert_eq!(client.config.pool_max_idle_per_host, 75);
assert_eq!(client.config.request_timeout, Duration::from_secs(8));
}
/// Sending to a host name that should not resolve must surface an error
/// instead of a response.
#[tokio::test]
async fn test_http_client_send_request_invalid_url() {
    let client = HttpRequestClient::new().unwrap();

    let url = "http://invalid-host-that-does-not-exist:9999/test".to_string();
    let body = Bytes::from("test");
    let no_headers = std::collections::HashMap::new();

    let outcome = client.send_request(url, body, no_headers).await;
    assert!(outcome.is_err());
}
/// End-to-end check: a request sent through `HttpRequestClient` reaches a
/// locally spawned hyper/axum server and the body arrives unchanged.
#[tokio::test]
async fn test_http2_client_server_integration() {
    use hyper_util::rt::{TokioExecutor, TokioIo};
    use hyper_util::server::conn::auto::Builder as ConnBuilder;
    use hyper_util::service::TowerToHyperService;

    // Shared state recording every request body the server receives.
    // (The previous `protocol_version` field was never read and has been removed.)
    #[derive(Clone)]
    struct TestState {
        received: Arc<TokioMutex<Vec<Bytes>>>,
    }

    async fn test_handler(
        AxumState(state): AxumState<TestState>,
        body: AxumBytes,
    ) -> &'static str {
        state.received.lock().await.push(body);
        "OK"
    }

    let state = TestState {
        received: Arc::new(TokioMutex::new(Vec::new())),
    };
    let app = Router::new()
        .route("/test", post(test_handler))
        .with_state(state.clone());

    // Bind to a random port
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    // Accept loop: one spawned task per incoming connection
    let server_handle = tokio::spawn(async move {
        loop {
            let Ok((stream, _)) = listener.accept().await else {
                break;
            };
            let app = app.clone();
            tokio::spawn(async move {
                let conn_builder = ConnBuilder::new(TokioExecutor::new());
                let io = TokioIo::new(stream);
                let tower_service = app.into_service();
                let hyper_service = TowerToHyperService::new(tower_service);
                let _ = conn_builder.serve_connection(io, hyper_service).await;
            });
        }
    });

    // Give server time to start
    tokio::time::sleep(Duration::from_millis(100)).await;

    // Create the client under test
    let client = HttpRequestClient::new().unwrap();

    // Send request
    let test_data = Bytes::from("test_payload");
    let result = client
        .send_request(
            format!("http://{}/test", addr),
            test_data.clone(),
            std::collections::HashMap::new(),
        )
        .await;

    // Verify request succeeded
    assert!(result.is_ok(), "Request failed: {:?}", result.err());

    // Verify server received the data
    tokio::time::sleep(Duration::from_millis(100)).await;
    let received = state.received.lock().await;
    assert_eq!(received.len(), 1);
    assert_eq!(received[0], test_data);

    // Cleanup
    server_handle.abort();
}
/// Custom headers handed to `send_request` must show up on the server side.
#[tokio::test]
async fn test_http2_headers_propagation() {
    use hyper_util::rt::{TokioExecutor, TokioIo};
    use hyper_util::server::conn::auto::Builder as ConnBuilder;
    use hyper_util::service::TowerToHyperService;

    // Server state: every (name, value) header pair whose value is valid text
    #[derive(Clone)]
    struct HeaderState {
        headers: Arc<TokioMutex<Vec<(String, String)>>>,
    }

    async fn header_handler(
        AxumState(state): AxumState<HeaderState>,
        headers: axum::http::HeaderMap,
    ) -> &'static str {
        let mut captured = state.headers.lock().await;
        captured.extend(headers.iter().filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|text| (name.to_string(), text.to_string()))
        }));
        "OK"
    }

    let state = HeaderState {
        headers: Arc::new(TokioMutex::new(Vec::new())),
    };
    let app = Router::new()
        .route("/test", post(header_handler))
        .with_state(state.clone());

    // Listen on an OS-assigned port
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    // Accept loop: one spawned task per incoming connection
    let server_handle = tokio::spawn(async move {
        while let Ok((stream, _)) = listener.accept().await {
            let app = app.clone();
            tokio::spawn(async move {
                let conn_builder = ConnBuilder::new(TokioExecutor::new());
                let io = TokioIo::new(stream);
                let hyper_service = TowerToHyperService::new(app.into_service());
                let _ = conn_builder.serve_connection(io, hyper_service).await;
            });
        }
    });

    // Give server time to start
    tokio::time::sleep(Duration::from_millis(100)).await;

    let client = HttpRequestClient::new().unwrap();

    // Request carrying two custom headers
    let headers = std::collections::HashMap::from([
        ("x-test-header".to_string(), "test-value".to_string()),
        ("x-request-id".to_string(), "req-123".to_string()),
    ]);
    let result = client
        .send_request(format!("http://{}/test", addr), Bytes::from("test"), headers)
        .await;

    // Verify request succeeded
    assert!(result.is_ok());

    // Verify headers were received
    tokio::time::sleep(Duration::from_millis(100)).await;
    let received_headers = state.headers.lock().await;
    let header_map: std::collections::HashMap<_, _> = received_headers
        .iter()
        .map(|(k, v)| (k.as_str(), v.as_str()))
        .collect();

    for (name, value) in [("x-test-header", "test-value"), ("x-request-id", "req-123")] {
        assert!(header_map.contains_key(name));
        assert_eq!(header_map.get(name), Some(&value));
    }

    // Cleanup
    server_handle.abort();
}
/// Ten requests issued concurrently through one shared client must all succeed
/// and all be counted by the server.
#[tokio::test]
async fn test_http2_concurrent_requests() {
    use hyper_util::rt::{TokioExecutor, TokioIo};
    use hyper_util::server::conn::auto::Builder as ConnBuilder;
    use hyper_util::service::TowerToHyperService;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Server state: number of requests handled so far
    #[derive(Clone)]
    struct CounterState {
        count: Arc<AtomicU64>,
    }

    async fn counter_handler(AxumState(state): AxumState<CounterState>) -> String {
        // Reply with the counter value observed before this request's increment
        state.count.fetch_add(1, Ordering::SeqCst).to_string()
    }

    let state = CounterState {
        count: Arc::new(AtomicU64::new(0)),
    };
    let app = Router::new()
        .route("/test", post(counter_handler))
        .with_state(state.clone());

    // Listen on an OS-assigned port
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    // Accept loop: one spawned task per incoming connection
    let server_handle = tokio::spawn(async move {
        while let Ok((stream, _)) = listener.accept().await {
            let app = app.clone();
            tokio::spawn(async move {
                let conn_builder = ConnBuilder::new(TokioExecutor::new());
                let io = TokioIo::new(stream);
                let hyper_service = TowerToHyperService::new(app.into_service());
                let _ = conn_builder.serve_connection(io, hyper_service).await;
            });
        }
    });

    // Give server time to start
    tokio::time::sleep(Duration::from_millis(100)).await;

    let client = Arc::new(HttpRequestClient::new().unwrap());

    // Launch all requests up front so they are in flight together
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let client = client.clone();
            tokio::spawn(async move {
                client
                    .send_request(
                        format!("http://{}/test", addr),
                        Bytes::from("test"),
                        std::collections::HashMap::new(),
                    )
                    .await
            })
        })
        .collect();

    // A request counts as successful only if its task joined AND the call returned Ok
    let mut success_count = 0;
    for handle in handles {
        if matches!(handle.await, Ok(Ok(_))) {
            success_count += 1;
        }
    }

    // Verify all requests succeeded
    assert_eq!(success_count, 10);
    // Verify server received all requests
    assert_eq!(state.count.load(Ordering::SeqCst), 10);

    // Cleanup
    server_handle.abort();
}
/// Coarse throughput smoke test: sends 100 concurrent 64KB POSTs to a server
/// spawned inside this test, then checks the server-side request/byte counters
/// against the number of successful calls and enforces minimum rates.
// NOTE(review): the requests/sec and MB/s floors at the end depend on wall-clock
// time, so the result may vary with machine load — confirm this is stable in CI.
#[tokio::test]
async fn test_http2_performance_benchmark() {
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as ConnBuilder;
use hyper_util::service::TowerToHyperService;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
// Create a test server that measures performance
// (counts handled requests and sums the received body sizes)
#[derive(Clone)]
struct PerfState {
request_count: Arc<AtomicU64>,
total_bytes: Arc<AtomicU64>,
}
async fn perf_handler(
AxumState(state): AxumState<PerfState>,
body: AxumBytes,
) -> &'static str {
state.request_count.fetch_add(1, Ordering::Relaxed);
state
.total_bytes
.fetch_add(body.len() as u64, Ordering::Relaxed);
"OK"
}
let state = PerfState {
request_count: Arc::new(AtomicU64::new(0)),
total_bytes: Arc::new(AtomicU64::new(0)),
};
let app = Router::new()
.route("/perf", post(perf_handler))
.with_state(state.clone());
// Bind to a random port
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Start HTTP/2 server
// (accept loop: every accepted connection is served on its own spawned task)
let server_handle = tokio::spawn(async move {
loop {
let Ok((stream, _)) = listener.accept().await else {
break;
};
let app = app.clone();
tokio::spawn(async move {
let conn_builder = ConnBuilder::new(TokioExecutor::new());
let io = TokioIo::new(stream);
let tower_service = app.into_service();
let hyper_service = TowerToHyperService::new(tower_service);
let _ = conn_builder.serve_connection(io, hyper_service).await;
});
}
});
// Give server time to start
tokio::time::sleep(Duration::from_millis(100)).await;
// Create optimized HTTP/2 client
let optimized_config = Http2Config {
max_frame_size: 1024 * 1024, // 1MB frames
max_concurrent_streams: 1000,
pool_max_idle_per_host: 100,
pool_idle_timeout: Duration::from_secs(90),
keep_alive_interval: Duration::from_secs(30),
keep_alive_timeout: Duration::from_secs(10),
adaptive_window: true,
request_timeout: Duration::from_secs(30),
};
let client = Arc::new(HttpRequestClient::with_config(optimized_config).unwrap());
// Performance test: Send many concurrent requests
let num_requests = 100;
let payload_size = 64 * 1024; // 64KB payload
let payload = Bytes::from(vec![0u8; payload_size]);
// The measured interval starts here and covers spawning plus awaiting every request
let start_time = Instant::now();
let mut handles = vec![];
for _ in 0..num_requests {
let client = client.clone();
let payload = payload.clone();
let handle = tokio::spawn(async move {
let headers = std::collections::HashMap::new();
client
.send_request(format!("http://{}/perf", addr), payload, headers)
.await
});
handles.push(handle);
}
// Wait for all requests to complete
// (a task that failed to join panics the test via `unwrap`; only `Ok` calls are counted)
let mut successful_requests = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
successful_requests += 1;
}
}
let elapsed = start_time.elapsed();
let requests_per_sec = successful_requests as f64 / elapsed.as_secs_f64();
let throughput_mbps =
(successful_requests * payload_size) as f64 / elapsed.as_secs_f64() / (1024.0 * 1024.0);
println!("Performance Results:");
println!(
" Successful requests: {}/{}",
successful_requests, num_requests
);
println!(" Total time: {:?}", elapsed);
println!(" Requests/sec: {:.2}", requests_per_sec);
println!(" Throughput: {:.2} MB/s", throughput_mbps);
// Verify server received all requests
let server_count = state.request_count.load(Ordering::Relaxed);
let server_bytes = state.total_bytes.load(Ordering::Relaxed);
assert_eq!(server_count, successful_requests as u64);
assert_eq!(server_bytes, (successful_requests * payload_size) as u64);
// Performance assertions (adjust based on your requirements)
assert!(successful_requests >= num_requests * 95 / 100); // At least 95% success rate
assert!(requests_per_sec > 50.0); // At least 50 requests per second
assert!(throughput_mbps > 10.0); // At least 10 MB/s throughput
// Cleanup
server_handle.abort();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NATS Request Plane Client
//!
//! Wraps the NATS client to implement the unified RequestPlaneClient trait,
//! providing a consistent interface across all transport types.
use super::unified_client::{ClientStats, Headers, RequestPlaneClient};
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
/// [`RequestPlaneClient`] backed by NATS.
///
/// Owns an `async_nats::Client` and exposes it through the unified
/// request-plane interface.
pub struct NatsRequestClient {
    client: async_nats::Client,
}

impl NatsRequestClient {
    /// Wrap an already constructed NATS client.
    ///
    /// # Arguments
    ///
    /// * `client` - The NATS client that requests are forwarded to
    pub fn new(client: async_nats::Client) -> Self {
        NatsRequestClient { client }
    }
}
#[async_trait]
impl RequestPlaneClient for NatsRequestClient {
    /// Issue a NATS request to `address` carrying `headers` and return the
    /// payload of the reply.
    ///
    /// # Errors
    ///
    /// Returns an error prefixed with `NATS request failed:` when the
    /// underlying request fails.
    async fn send_request(
        &self,
        address: String,
        payload: Bytes,
        headers: Headers,
    ) -> Result<Bytes> {
        // Copy the transport-agnostic headers into a NATS header map
        let mut nats_headers = async_nats::HeaderMap::new();
        for (name, value) in headers {
            nats_headers.insert(name.as_str(), value.as_str());
        }

        let outcome = self
            .client
            .request_with_headers(address, nats_headers, payload)
            .await;
        match outcome {
            Ok(reply) => Ok(reply.payload),
            Err(e) => Err(anyhow::anyhow!("NATS request failed: {}", e)),
        }
    }

    /// Name of this transport.
    fn transport_name(&self) -> &'static str {
        "nats"
    }

    /// Always reports healthy; the connection state is not inspected here.
    fn is_healthy(&self) -> bool {
        // NOTE(review): confirm whether the async_nats version in use offers a
        // connection-state query that could back this check.
        true
    }

    /// Placeholder statistics: every counter is zero, and `active_connections`
    /// mirrors `is_healthy()` (1 when healthy, 0 otherwise).
    fn stats(&self) -> ClientStats {
        let active_connections = if self.is_healthy() { 1 } else { 0 };
        ClientStats {
            requests_sent: 0,
            responses_received: 0,
            errors: 0,
            bytes_sent: 0,
            bytes_received: 0,
            active_connections,
            idle_connections: 0,
            avg_latency_us: 0,
        }
    }

    /// No-op: nothing is closed explicitly, the wrapped client is left as is.
    async fn close(&self) -> Result<()> {
        Ok(())
    }
}
......@@ -89,13 +89,17 @@ impl RouterMode {
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
let Some(nats_client) = endpoint.drt().nats_client() else {
anyhow::bail!("Missing NATS. Please ensure it is running and accessible.");
};
AddressedPushRouter::new(
nats_client.client().clone(),
endpoint.drt().tcp_server().await?,
)
// Get network manager and create client (no mode checks!)
let manager = endpoint.drt().network_manager().await?;
let req_client = manager.create_client()?;
let resp_transport = endpoint.drt().tcp_server().await?;
tracing::debug!(
transport = req_client.transport_name(),
"Creating AddressedPushRouter with request plane client"
);
AddressedPushRouter::new(req_client, resp_transport)
}
impl<T, U> PushRouter<T, U>
......@@ -224,8 +228,48 @@ where
}
}
let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
// Get the address based on discovered transport type
let address = {
use crate::component::TransportType;
// Get the instance and use its actual transport type
let instances = self.client.instances();
let instance = instances
.iter()
.find(|i| i.instance_id == instance_id)
.ok_or_else(|| {
anyhow::anyhow!("Instance {} not found in available instances", instance_id)
})?;
match &instance.transport {
TransportType::Http(http_endpoint) => {
tracing::debug!(
instance_id = instance_id,
http_endpoint = %http_endpoint,
"Using HTTP transport for instance"
);
http_endpoint.clone()
}
TransportType::Tcp(tcp_endpoint) => {
tracing::debug!(
instance_id = instance_id,
tcp_endpoint = %tcp_endpoint,
"Using TCP transport for instance"
);
tcp_endpoint.clone()
}
TransportType::Nats(subject) => {
tracing::debug!(
instance_id = instance_id,
subject = %subject,
"Using NATS transport for instance"
);
subject.clone()
}
}
};
let request = request.map(|req| AddressedRequest::new(req, address));
let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
match stream {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TCP Request Plane Client
//!
use super::unified_client::{ClientStats, Headers, RequestPlaneClient};
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use dashmap::DashMap;
use futures::StreamExt;
use std::io::IoSlice;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_util::codec::FramedRead;
/// Default timeout for TCP request acknowledgment, in seconds
/// (used as the default `TcpRequestConfig::request_timeout`).
const DEFAULT_TCP_REQUEST_TIMEOUT_SECS: u64 = 5;
/// Default connection pool size per host
/// (used as the default `TcpRequestConfig::pool_size`).
const DEFAULT_POOL_SIZE: usize = 100;
/// Buffer size for request channel per connection (backpressure control)
/// (used as the default `TcpRequestConfig::channel_buffer`).
const REQUEST_CHANNEL_BUFFER: usize = 50;
// NOTE(review): no use of READ_BUFFER_SIZE is visible in the client code that
// follows — confirm it is still referenced somewhere.
/// Pre-allocated read buffer size (64KB typical message size)
const READ_BUFFER_SIZE: usize = 65536;
/// Default maximum message size for TCP client (32 MB)
/// This is the limit for a SINGLE message. For larger data, split into multiple messages.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Resolve the maximum message size.
///
/// Reads `DYN_TCP_MAX_MESSAGE_SIZE`; when the variable is unset, unreadable, or
/// does not parse as `usize`, [`DEFAULT_MAX_MESSAGE_SIZE`] is returned.
fn get_max_message_size() -> usize {
    match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
        Err(_) => DEFAULT_MAX_MESSAGE_SIZE,
    }
}
/// TCP request plane configuration
#[derive(Debug, Clone)]
pub struct TcpRequestConfig {
/// Request timeout
pub request_timeout: Duration,
/// Maximum connections per host
pub pool_size: usize,
/// Connect timeout
pub connect_timeout: Duration,
/// Request channel buffer size
pub channel_buffer: usize,
}
impl Default for TcpRequestConfig {
fn default() -> Self {
Self {
request_timeout: Duration::from_secs(DEFAULT_TCP_REQUEST_TIMEOUT_SECS),
pool_size: DEFAULT_POOL_SIZE,
connect_timeout: Duration::from_secs(5),
channel_buffer: REQUEST_CHANNEL_BUFFER,
}
}
}
impl TcpRequestConfig {
/// Create configuration from environment variables
pub fn from_env() -> Self {
let mut config = Self::default();
if let Ok(val) = std::env::var("DYN_TCP_REQUEST_TIMEOUT")
&& let Ok(timeout) = val.parse::<u64>()
{
config.request_timeout = Duration::from_secs(timeout);
}
if let Ok(val) = std::env::var("DYN_TCP_POOL_SIZE")
&& let Ok(size) = val.parse::<usize>()
{
config.pool_size = size;
}
if let Ok(val) = std::env::var("DYN_TCP_CONNECT_TIMEOUT")
&& let Ok(timeout) = val.parse::<u64>()
{
config.connect_timeout = Duration::from_secs(timeout);
}
if let Ok(val) = std::env::var("DYN_TCP_CHANNEL_BUFFER")
&& let Ok(size) = val.parse::<usize>()
{
config.channel_buffer = size;
}
config
}
}
/// Request to be sent over TCP
/// Pre-encoded on caller's thread for optimal write performance (hot path optimization)
///
/// One value is queued to the connection's writer task per call; it pairs the
/// bytes to write with the channel on which the outcome is reported.
struct TcpRequest {
/// Pre-encoded request data ready to send (zero-copy Bytes)
/// Encoding happens on caller thread to parallelize across multiple request handlers
encoded_data: Bytes,
/// Oneshot channel to send response back to caller
/// (carries either the response bytes or the error that ended the request)
response_tx: oneshot::Sender<Result<Bytes>>,
}
/// TCP connection with split read/write tasks
///
/// Design: One writer task + one reader task per connection
/// - Writer task receives pre-encoded requests and writes directly (hot path optimized)
/// - Reader task uses framed codec for robust protocol handling
/// - FIFO ordering ensures request/response correlation without explicit IDs
///
/// Performance: Hybrid approach optimizes each path independently:
/// - Write path: Pre-encode on caller thread → direct write (minimal overhead, parallel encoding)
/// - Read path: Framed codec handles partial reads and protocol complexity automatically
///
/// Dropping a `TcpConnection` aborts both tasks (see the `Drop` impl below).
struct TcpConnection {
    addr: SocketAddr,
    /// Channel to send requests to the writer task
    request_tx: mpsc::Sender<TcpRequest>,
    /// Writer task handle for cleanup
    writer_handle: Arc<JoinHandle<()>>,
    /// Reader task handle for cleanup
    reader_handle: Arc<JoinHandle<()>>,
    /// Health status (false if tasks have failed)
    healthy: Arc<AtomicBool>,
}

impl Drop for TcpConnection {
    /// Abort the writer and reader tasks when the connection is discarded.
    ///
    /// Dropping a `JoinHandle` only detaches its task. Without this, a reader
    /// that is still waiting for a response (for example after the caller
    /// timed out) would keep running, and keep its half of the socket, until
    /// the peer closes the connection.
    fn drop(&mut self) {
        self.writer_handle.abort();
        self.reader_handle.abort();
    }
}
impl TcpConnection {
/// Open a TCP connection to `addr` and start its writer and reader tasks.
///
/// `timeout` bounds the connect attempt; `channel_buffer` is the capacity of
/// the bounded queue that feeds the writer task.
///
/// # Errors
///
/// Fails when the connect attempt times out or errors, or when the socket
/// cannot be configured.
async fn connect(addr: SocketAddr, timeout: Duration, channel_buffer: usize) -> Result<Self> {
    let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
        Ok(connected) => connected?,
        Err(_) => return Err(anyhow::anyhow!("TCP connect timeout to {}", addr)),
    };

    // Configure socket for lower latency
    Self::configure_socket(&stream)?;
    let (read_half, write_half) = tokio::io::split(stream);

    // Callers -> writer task
    let (request_tx, request_rx) = mpsc::channel::<TcpRequest>(channel_buffer);
    // Writer task -> reader task: reply channels, forwarded in write order so
    // the reader can pair each response with its request (FIFO correlation)
    let (response_tx_channel, response_rx_channel) =
        mpsc::unbounded_channel::<oneshot::Sender<Result<Bytes>>>();

    // Cleared by whichever task fails first
    let healthy = Arc::new(AtomicBool::new(true));

    let writer_healthy = Arc::clone(&healthy);
    let writer_handle = tokio::spawn(async move {
        let outcome = Self::writer_task(write_half, request_rx, response_tx_channel).await;
        if let Err(e) = outcome {
            tracing::debug!("Writer task failed for {}: {}", addr, e);
            writer_healthy.store(false, Ordering::Relaxed);
        }
    });

    let reader_healthy = Arc::clone(&healthy);
    let reader_handle = tokio::spawn(async move {
        let outcome = Self::reader_task(read_half, response_rx_channel).await;
        if let Err(e) = outcome {
            tracing::debug!("Reader task failed for {}: {}", addr, e);
            reader_healthy.store(false, Ordering::Relaxed);
        }
    });

    Ok(Self {
        addr,
        request_tx,
        writer_handle: Arc::new(writer_handle),
        reader_handle: Arc::new(reader_handle),
        healthy,
    })
}
/// Configure socket for ultra-low latency based on dyn-transports patterns
///
/// Always applied: `TCP_NODELAY` and 2MB send/receive buffers; a failure of
/// any of these is returned to the caller.
///
/// With the `tcp-low-latency` feature enabled, `TCP_QUICKACK` and
/// `SO_BUSY_POLL` are additionally requested on a best-effort basis: a failure
/// is logged at debug level and does not fail the connection.
fn configure_socket(stream: &TcpStream) -> Result<()> {
    use socket2::SockRef;

    let sock_ref = SockRef::from(stream);
    // TCP_NODELAY - disable Nagle's algorithm for immediate send
    sock_ref.set_nodelay(true)?;
    // Increase socket buffer sizes for better throughput under load
    sock_ref.set_recv_buffer_size(2 * 1024 * 1024)?; // 2MB
    sock_ref.set_send_buffer_size(2 * 1024 * 1024)?; // 2MB

    // Advanced Linux optimizations for ultra-low latency (optional feature)
    #[cfg(feature = "tcp-low-latency")]
    {
        use std::os::unix::io::AsRawFd;

        let fd = stream.as_raw_fd();
        let set_int_option = |level: libc::c_int,
                              name: libc::c_int,
                              value: libc::c_int|
         -> std::io::Result<()> {
            // SAFETY: `fd` belongs to `stream`, which is borrowed for the whole
            // call and therefore still open; the option value pointer refers to
            // a live `c_int` and the length passed is exactly its size.
            let rc = unsafe {
                libc::setsockopt(
                    fd,
                    level,
                    name,
                    &value as *const libc::c_int as *const libc::c_void,
                    std::mem::size_of::<libc::c_int>() as libc::socklen_t,
                )
            };
            if rc == 0 {
                Ok(())
            } else {
                Err(std::io::Error::last_os_error())
            }
        };

        // TCP_QUICKACK - minimize ACK delay
        let quickack = set_int_option(libc::SOL_TCP, libc::TCP_QUICKACK, 1);
        // SO_BUSY_POLL - enable busy polling for lower latency (50 microseconds)
        let busy_poll = set_int_option(libc::SOL_SOCKET, libc::SO_BUSY_POLL, 50);

        if let Err(e) = &quickack {
            tracing::debug!("Failed to set TCP_QUICKACK: {}", e);
        }
        if let Err(e) = &busy_poll {
            tracing::debug!("Failed to set SO_BUSY_POLL: {}", e);
        }
        // Only claim success when both options were actually applied
        if quickack.is_ok() && busy_poll.is_ok() {
            tracing::debug!("TCP low-latency optimizations enabled (TCP_QUICKACK, SO_BUSY_POLL)");
        }
    }

    Ok(())
}
/// Writer task: receives pre-encoded requests and writes directly to socket
///
/// Performance optimization: Pre-encoding happens on caller's thread to enable
/// parallel encoding across multiple request handlers, while this task focuses
/// on sequential socket writes with minimal overhead.
async fn writer_task(
mut write_half: tokio::io::WriteHalf<TcpStream>,
mut request_rx: mpsc::Receiver<TcpRequest>,
response_tx_channel: mpsc::UnboundedSender<oneshot::Sender<Result<Bytes>>>,
) -> Result<()> {
while let Some(req) = request_rx.recv().await {
// Direct write of pre-encoded data (hot path)
// With TCP_NODELAY, no need for explicit flush()
match write_half.write_all(&req.encoded_data).await {
Ok(()) => {
// Forward response channel to reader task (FIFO ordering)
if response_tx_channel.send(req.response_tx).is_err() {
tracing::debug!("Reader task closed, stopping writer");
break;
}
}
Err(e) => {
// Write failed, notify caller and stop
let err_msg = format!("Write failed: {}", e);
let _ = req.response_tx.send(Err(anyhow::anyhow!("{}", err_msg)));
return Err(anyhow::anyhow!("{}", err_msg));
}
}
}
Ok(())
}
/// Reader task: reads responses using framed codec and sends them back via oneshot channels
/// Protocol framing handled automatically via TcpResponseCodec
///
/// For every reply channel received from the writer task, exactly one framed
/// response is read and delivered. Returns `Ok(())` once the writer side has
/// closed its channel, and `Err` on a decode failure or when the peer closes
/// the connection. In both failure cases the waiting caller and the returned
/// error carry the same message, so the task-level log keeps the cause.
async fn reader_task(
    read_half: tokio::io::ReadHalf<TcpStream>,
    mut response_rx_channel: mpsc::UnboundedReceiver<oneshot::Sender<Result<Bytes>>>,
) -> Result<()> {
    use crate::pipeline::network::codec::TcpResponseCodec;

    let max_message_size = get_max_message_size();
    let codec = TcpResponseCodec::new(Some(max_message_size));
    let mut framed = FramedRead::new(read_half, codec);

    while let Some(response_tx) = response_rx_channel.recv().await {
        // Read the next response message from the framed stream
        // The codec handles all protocol framing and size checks automatically
        let err_msg = match framed.next().await {
            Some(Ok(response_msg)) => {
                // The caller may have stopped waiting; that is not an error here
                let _ = response_tx.send(Ok(response_msg.data));
                continue;
            }
            Some(Err(e)) => format!("Failed to decode response: {}", e),
            None => "Connection closed by peer".to_string(),
        };

        // Report the same failure to the waiting caller and to the task owner
        let _ = response_tx.send(Err(anyhow::anyhow!("{}", err_msg)));
        return Err(anyhow::anyhow!("{}", err_msg));
    }
    Ok(())
}
/// Send a request and wait for response
///
/// Performance: Encoding happens on caller's thread (hot path optimization)
/// to enable parallel encoding across multiple request handlers. The writer
/// task then performs sequential writes with minimal overhead.
async fn send_request(&self, payload: Bytes, headers: &Headers) -> Result<Bytes> {
use crate::pipeline::network::codec::TcpRequestMessage;
// Check health before sending
if !self.healthy.load(Ordering::Relaxed) {
anyhow::bail!("Connection unhealthy (tasks failed)");
}
// Extract endpoint path from headers (required for routing)
let endpoint_path = headers
.get("x-endpoint-path")
.ok_or_else(|| anyhow::anyhow!("Missing x-endpoint-path header for TCP request"))?
.to_string();
// Encode request on caller's thread (hot path optimization)
// This allows multiple concurrent callers to encode in parallel
// rather than serializing through the writer task
let request_msg = TcpRequestMessage::new(endpoint_path, payload);
let encoded_data = request_msg.encode()?;
// Create response channel
let (response_tx, response_rx) = oneshot::channel();
// Send to writer task (bounded channel provides backpressure)
let req = TcpRequest {
encoded_data,
response_tx,
};
self.request_tx
.send(req)
.await
.map_err(|_| anyhow::anyhow!("Writer task closed"))?;
// Wait for response from reader task
response_rx
.await
.map_err(|_| anyhow::anyhow!("Reader task closed"))?
}
/// Check if connection is healthy
///
/// Returns the current value of the shared `healthy` flag, which the spawned
/// writer/reader wrappers clear when their task returns an error.
fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::Relaxed)
}
}
/// Connection pool with health checking for TCP connections
///
/// Idle connections are kept per remote address; a connection is taken out of
/// the pool while in use and handed back afterwards.
struct TcpConnectionPool {
    /// Idle connections, keyed by remote address
    pools: DashMap<SocketAddr, Arc<Mutex<Vec<TcpConnection>>>>,
    /// Supplies the connect timeout, channel buffer and per-address pool size
    config: TcpRequestConfig,
}

impl TcpConnectionPool {
    /// Create an empty pool that uses `config` for new connections.
    fn new(config: TcpRequestConfig) -> Self {
        Self {
            pools: DashMap::new(),
            config,
        }
    }

    /// Get a connection from the pool or create a new one
    /// Automatically filters out unhealthy connections
    ///
    /// # Errors
    ///
    /// Fails only when no healthy idle connection exists and establishing a
    /// new one fails.
    async fn get_connection(&self, addr: SocketAddr) -> Result<TcpConnection> {
        // Clone the per-address list out of the map first so the DashMap guard
        // is released before the `.await` below; holding a map guard across an
        // await point can block other tasks that need the same shard.
        let idle = self
            .pools
            .get(&addr)
            .map(|entry| Arc::clone(entry.value()));

        if let Some(idle) = idle {
            let mut idle = idle.lock().await;
            // Try to get a healthy connection, discard unhealthy ones
            while let Some(conn) = idle.pop() {
                if conn.is_healthy() {
                    return Ok(conn);
                }
                tracing::debug!("Discarding unhealthy connection for {}", addr);
                // `conn` is dropped here
            }
        }

        // Create new connection with configured channel buffer
        tracing::debug!("Creating new TCP connection to {}", addr);
        TcpConnection::connect(
            addr,
            self.config.connect_timeout,
            self.config.channel_buffer,
        )
        .await
    }

    /// Return a connection to the pool if it's healthy and there's space
    ///
    /// Unhealthy connections, and connections that would exceed
    /// `config.pool_size` for their address, are dropped instead.
    async fn return_connection(&self, conn: TcpConnection) {
        // Only return healthy connections
        if !conn.is_healthy() {
            tracing::debug!("Not returning unhealthy connection to pool");
            return;
        }

        let addr = conn.addr;
        // Get or create the list for this address; the map guard is released
        // at the end of this statement, before the lock is awaited
        let pool = self
            .pools
            .entry(addr)
            .or_insert_with(|| Arc::new(Mutex::new(Vec::new())))
            .clone();

        let mut pool = pool.lock().await;
        if pool.len() < self.config.pool_size {
            pool.push(conn);
        } else {
            tracing::debug!("Connection pool full for {}, dropping connection", addr);
            // `conn` is dropped when it goes out of scope
        }
    }
}
/// TCP request plane client
///
/// Sends requests over pooled TCP connections and keeps running counters of
/// requests, responses, errors and payload bytes.
pub struct TcpRequestClient {
// Connection pool used to obtain and hand back connections
pool: Arc<TcpConnectionPool>,
// Client configuration; `request_timeout` bounds each request
config: TcpRequestConfig,
// Counters reported through `stats()`
stats: Arc<TcpClientStats>,
}
/// Atomic counters backing `TcpRequestClient::stats()`.
struct TcpClientStats {
requests_sent: AtomicU64,
responses_received: AtomicU64,
errors: AtomicU64,
bytes_sent: AtomicU64,
bytes_received: AtomicU64,
}
impl TcpRequestClient {
    /// Build a client using [`TcpRequestConfig::default`].
    pub fn new() -> Result<Self> {
        Self::with_config(TcpRequestConfig::default())
    }

    /// Build a client from an explicit configuration.
    ///
    /// The connection pool receives its own copy of `config`; all counters
    /// start at zero.
    pub fn with_config(config: TcpRequestConfig) -> Result<Self> {
        let stats = TcpClientStats {
            requests_sent: AtomicU64::new(0),
            responses_received: AtomicU64::new(0),
            errors: AtomicU64::new(0),
            bytes_sent: AtomicU64::new(0),
            bytes_received: AtomicU64::new(0),
        };
        let pool = TcpConnectionPool::new(config.clone());

        Ok(Self {
            pool: Arc::new(pool),
            config,
            stats: Arc::new(stats),
        })
    }

    /// Build a client configured via [`TcpRequestConfig::from_env`].
    pub fn from_env() -> Result<Self> {
        Self::with_config(TcpRequestConfig::from_env())
    }

    /// Split a request address into its socket address and optional endpoint name.
    ///
    /// Accepted forms: `ip:port`, `tcp://ip:port`, and either of those followed
    /// by `/endpoint_name`. The socket part is parsed as a [`SocketAddr`], so
    /// it must be a literal IP address and port.
    ///
    /// # Errors
    ///
    /// Returns an error naming the full input when the socket part does not parse.
    fn parse_address(address: &str) -> Result<(SocketAddr, Option<String>)> {
        let addr_str = address.strip_prefix("tcp://").unwrap_or(address);

        // Everything after the first '/' (if any) is the endpoint name
        let (socket_part, endpoint_name) = match addr_str.split_once('/') {
            Some((socket_part, endpoint_name)) => (socket_part, Some(endpoint_name.to_string())),
            None => (addr_str, None),
        };

        let socket_addr = socket_part
            .parse::<SocketAddr>()
            .map_err(|e| anyhow::anyhow!("Invalid TCP address '{}': {}", address, e))?;
        Ok((socket_addr, endpoint_name))
    }
}
impl Default for TcpRequestClient {
fn default() -> Self {
Self::new().expect("Failed to create TCP request client")
}
}
#[async_trait]
impl RequestPlaneClient for TcpRequestClient {
    /// Send `payload` to `address` over a pooled TCP connection and return the response.
    ///
    /// `address` is parsed by `parse_address`; when it carries an endpoint name,
    /// that name is written into the `x-endpoint-path` header (replacing any
    /// existing value). The request is bounded by `config.request_timeout`.
    ///
    /// Every call increments `requests_sent`, and every failing call — invalid
    /// address, connection failure, request error, or timeout — increments
    /// `errors`. A connection is returned to the pool only after a successful
    /// response.
    async fn send_request(
        &self,
        address: String,
        payload: Bytes,
        mut headers: Headers,
    ) -> Result<Bytes> {
        tracing::debug!("TCP client sending request to address: {}", address);
        self.stats.requests_sent.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_sent
            .fetch_add(payload.len() as u64, Ordering::Relaxed);

        let (addr, endpoint_name) = match Self::parse_address(&address) {
            Ok(parsed) => parsed,
            Err(e) => {
                // Count the failure so requests_sent stays reconcilable with
                // responses_received + errors
                self.stats.errors.fetch_add(1, Ordering::Relaxed);
                return Err(e);
            }
        };

        // Add endpoint path to headers if present in address
        if let Some(endpoint_name) = endpoint_name {
            headers.insert("x-endpoint-path".to_string(), endpoint_name);
        }

        // Get connection from pool (automatically filters unhealthy connections)
        let conn = match self.pool.get_connection(addr).await {
            Ok(conn) => conn,
            Err(e) => {
                self.stats.errors.fetch_add(1, Ordering::Relaxed);
                tracing::warn!("TCP connection to {} failed: {}", addr, e);
                return Err(e);
            }
        };

        // Send request with timeout
        let result = tokio::time::timeout(
            self.config.request_timeout,
            conn.send_request(payload, &headers),
        )
        .await;

        match result {
            Ok(Ok(response)) => {
                self.stats
                    .responses_received
                    .fetch_add(1, Ordering::Relaxed);
                self.stats
                    .bytes_received
                    .fetch_add(response.len() as u64, Ordering::Relaxed);
                // Return connection to pool (health check happens inside)
                self.pool.return_connection(conn).await;
                Ok(response)
            }
            Ok(Err(e)) => {
                self.stats.errors.fetch_add(1, Ordering::Relaxed);
                tracing::warn!("TCP request failed to {}: {}", addr, e);
                // Don't return unhealthy connection to pool, let it drop
                Err(e)
            }
            Err(_) => {
                self.stats.errors.fetch_add(1, Ordering::Relaxed);
                tracing::warn!("TCP request timeout to {}", addr);
                // Don't return timed-out connection to pool
                Err(anyhow::anyhow!("TCP request timeout to {}", addr))
            }
        }
    }

    /// Name of this transport.
    fn transport_name(&self) -> &'static str {
        "tcp"
    }

    /// Always `true`; individual connections are health-checked by the pool.
    fn is_healthy(&self) -> bool {
        true // TCP client is always healthy if it was created successfully
    }

    /// Snapshot of the request/response/error/byte counters.
    ///
    /// Connection counts and average latency are not tracked and are reported as 0.
    fn stats(&self) -> ClientStats {
        ClientStats {
            requests_sent: self.stats.requests_sent.load(Ordering::Relaxed),
            responses_received: self.stats.responses_received.load(Ordering::Relaxed),
            errors: self.stats.errors.load(Ordering::Relaxed),
            bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
            bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
            active_connections: 0, // Could track this if needed
            idle_connections: 0,
            avg_latency_us: 0,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicUsize;
use tokio::net::TcpListener;
/// The default configuration must mirror the module-level default constants.
#[test]
fn test_tcp_config_default() {
    let TcpRequestConfig {
        request_timeout,
        pool_size,
        channel_buffer,
        ..
    } = TcpRequestConfig::default();

    assert_eq!(pool_size, DEFAULT_POOL_SIZE);
    assert_eq!(
        request_timeout,
        Duration::from_secs(DEFAULT_TCP_REQUEST_TIMEOUT_SECS)
    );
    assert_eq!(channel_buffer, REQUEST_CHANNEL_BUFFER);
}
/// `TcpRequestConfig::from_env` must honour every `DYN_TCP_*` override set below.
#[test]
fn test_tcp_config_from_env() {
    const OVERRIDES: [(&str, &str); 4] = [
        ("DYN_TCP_REQUEST_TIMEOUT", "10"),
        ("DYN_TCP_POOL_SIZE", "50"),
        ("DYN_TCP_CONNECT_TIMEOUT", "3"),
        ("DYN_TCP_CHANNEL_BUFFER", "100"),
    ];

    // SAFETY: mutating the process environment is only sound while no other
    // thread reads or writes it.
    // NOTE(review): the default parallel test runner does not guarantee that;
    // consider serialising the env-mutating tests.
    unsafe {
        for (name, value) in OVERRIDES {
            std::env::set_var(name, value);
        }
    }

    let config = TcpRequestConfig::from_env();

    // Clean up env vars BEFORE asserting, so a failed assertion cannot leave
    // the overrides behind for other tests in this process.
    // SAFETY: same consideration as for `set_var` above.
    unsafe {
        for (name, _) in OVERRIDES {
            std::env::remove_var(name);
        }
    }

    assert_eq!(config.request_timeout, Duration::from_secs(10));
    assert_eq!(config.pool_size, 50);
    assert_eq!(config.connect_timeout, Duration::from_secs(3));
    assert_eq!(config.channel_buffer, 100);
}
#[test]
fn test_parse_address() {
let (addr1, _) = TcpRequestClient::parse_address("127.0.0.1:8080").unwrap();
assert_eq!(addr1.port(), 8080);
let (addr2, _) = TcpRequestClient::parse_address("tcp://127.0.0.1:9090").unwrap();
assert_eq!(addr2.port(), 9090);
let (addr3, endpoint) =
TcpRequestClient::parse_address("127.0.0.1:8080/test_endpoint").unwrap();
assert_eq!(addr3.port(), 8080);
assert_eq!(endpoint, Some("test_endpoint".to_string()));
assert!(TcpRequestClient::parse_address("invalid").is_err());
}
#[test]
fn test_tcp_client_creation() {
let client = TcpRequestClient::new();
assert!(client.is_ok());
let client = client.unwrap();
assert_eq!(client.transport_name(), "tcp");
assert!(client.is_healthy());
}
#[tokio::test]
async fn test_connection_health_check() {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage};
// Start a mock TCP server
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Spawn server that responds to requests
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
// Read request
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await.unwrap();
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await.unwrap();
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await.unwrap();
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
read_half.read_exact(&mut payload_buf).await.unwrap();
// Send response
let response = TcpResponseMessage::new(Bytes::from_static(b"pong"));
let encoded = response.encode().unwrap();
write_half.write_all(&encoded).await.unwrap();
});
// Create connection
let conn = TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap();
assert!(conn.is_healthy(), "New connection should be healthy");
// Send a request
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let result = conn.send_request(Bytes::from("ping"), &headers).await;
assert!(result.is_ok(), "Request should succeed");
assert_eq!(result.unwrap(), Bytes::from("pong"));
assert!(
conn.is_healthy(),
"Connection should remain healthy after successful request"
);
}
#[tokio::test]
async fn test_concurrent_requests_single_connection() {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage};
// Start a mock TCP server that handles multiple requests
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
// Spawn server that responds to multiple requests
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
// Handle 5 requests
for _ in 0..5 {
// Read request
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
if read_half.read_exact(&mut path_buf).await.is_err() {
break;
}
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
if read_half.read_exact(&mut payload_buf).await.is_err() {
break;
}
request_count_clone.fetch_add(1, Ordering::SeqCst);
// Send response
let response = TcpResponseMessage::new(Bytes::from(payload_buf));
let encoded = response.encode().unwrap();
if write_half.write_all(&encoded).await.is_err() {
break;
}
}
});
// Create connection
let conn = Arc::new(
TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap(),
);
// Send 5 concurrent requests
let mut handles = vec![];
for i in 0..5 {
let conn = conn.clone();
let handle = tokio::spawn(async move {
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let payload = format!("request_{}", i);
conn.send_request(Bytes::from(payload.clone()), &headers)
.await
.map(|response| (payload, response))
});
handles.push(handle);
}
// Wait for all requests to complete
let mut results = vec![];
for handle in handles {
let result = handle.await.unwrap();
assert!(result.is_ok(), "Request should succeed");
results.push(result.unwrap());
}
// Verify all requests got responses
assert_eq!(results.len(), 5);
// Verify server received all requests
assert_eq!(
request_count.load(Ordering::SeqCst),
5,
"Server should have received 5 requests"
);
}
#[tokio::test]
async fn test_connection_pool_reuse() {
use crate::pipeline::network::codec::TcpResponseMessage;
// Start a mock TCP server
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let connection_count = Arc::new(AtomicUsize::new(0));
let connection_count_clone = connection_count.clone();
// Spawn server that accepts multiple connections
tokio::spawn(async move {
loop {
let result = listener.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
connection_count_clone.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
// Read request
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
if read_half.read_exact(&mut path_buf).await.is_err() {
break;
}
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
if read_half.read_exact(&mut payload_buf).await.is_err() {
break;
}
// Send response
let response = TcpResponseMessage::new(Bytes::from_static(b"ok"));
let encoded = response.encode().unwrap();
if write_half.write_all(&encoded).await.is_err() {
break;
}
}
});
}
});
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config);
// Get connection twice from pool
let conn1 = pool.get_connection(addr).await.unwrap();
pool.return_connection(conn1).await;
// Small delay to ensure connection is returned
tokio::time::sleep(Duration::from_millis(10)).await;
let conn2 = pool.get_connection(addr).await.unwrap();
pool.return_connection(conn2).await;
// Should have created only 1 TCP connection since we reused
assert_eq!(
connection_count.load(Ordering::SeqCst),
1,
"Should reuse connection from pool"
);
}
#[tokio::test]
async fn test_unhealthy_connection_filtered() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Server that immediately closes connections
tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
drop(stream); // Immediately close
}
});
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(1),
connect_timeout: Duration::from_secs(1),
pool_size: 2,
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config.clone());
// Try to get a connection - it will become unhealthy quickly
let result =
TcpConnection::connect(addr, config.connect_timeout, config.channel_buffer).await;
if let Ok(conn) = result {
// Mark as unhealthy by trying to use it
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let _ = conn.send_request(Bytes::from("test"), &headers).await;
// Return to pool
pool.return_connection(conn).await;
// Try to get from pool again - should get a new connection attempt
let result2 = pool.get_connection(addr).await;
// This might fail or succeed depending on timing, but should not panic
let _ = result2;
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Unified Request Plane Client Interface
//!
//! This module defines a transport-agnostic interface for sending requests
//! in the request plane. All transport implementations (TCP, HTTP, NATS)
//! implement this trait to provide a consistent interface for the egress router.
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use std::collections::HashMap;
/// Request headers passed alongside a payload: a map from header name to
/// header value, both owned strings.
pub type Headers = HashMap<String, String>;
/// Unified interface for request plane clients
///
/// This trait abstracts over different transport mechanisms (TCP, HTTP, NATS)
/// providing a consistent interface for sending requests and receiving acknowledgments.
///
/// # Design Principles
///
/// 1. **Transport Agnostic**: Implementations can be swapped without changing router logic
/// 2. **Async by Default**: Sending and closing are async to support high concurrency
/// 3. **Headers Support**: All transports must support custom headers for tracing, etc.
/// 4. **Health Checks**: Implementations should provide connection health information
/// 5. **Error Handling**: All errors are wrapped in anyhow::Result for flexibility
///
/// # Example
///
/// ```ignore
/// use bytes::Bytes;
/// use std::collections::HashMap;
/// use dynamo_runtime::pipeline::network::egress::RequestPlaneClient;
///
/// async fn send_request(client: &dyn RequestPlaneClient) -> Result<()> {
///     let mut headers = HashMap::new();
///     headers.insert("x-request-id".to_string(), "123".to_string());
///
///     let _ack = client.send_request(
///         "service-endpoint".to_string(),
///         Bytes::from("payload"),
///         headers,
///     ).await?;
///
///     Ok(())
/// }
/// ```
#[async_trait]
pub trait RequestPlaneClient: Send + Sync {
/// Send a request to a specific address and wait for acknowledgment
///
/// # Arguments
///
/// * `address` - Transport-specific address:
///   - HTTP: `http://host:port/path`
///   - TCP: `host:port` or `tcp://host:port`
///   - NATS: `subject.name`
/// * `payload` - Request payload (encoded as bytes)
/// * `headers` - Custom headers for tracing, authentication, etc.
///
/// # Returns
///
/// Returns the acknowledgment bytes produced by the transport. Note that for
/// streaming responses, the actual response data comes over the TCP response
/// plane, not through this acknowledgment.
///
/// # Errors
///
/// Returns an error if:
/// - Connection to the endpoint fails
/// - Request times out
/// - Transport-specific errors occur (e.g., NATS server unavailable)
async fn send_request(
&self,
address: String,
payload: Bytes,
headers: Headers,
) -> Result<Bytes>;
/// Get the transport name
///
/// Returns a static string identifier for the transport type.
/// Used for logging and debugging.
///
/// # Examples
///
/// - `"tcp"` - Raw TCP transport
/// - `"http"` or `"http2"` - HTTP/2 transport
/// - `"nats"` - NATS messaging
fn transport_name(&self) -> &'static str;
/// Check connection health
///
/// Returns `true` if the client is healthy and ready to send requests.
/// This is meant to be a lightweight check that doesn't perform actual
/// network I/O.
///
/// Implementations should return `false` if:
/// - Connection pool is exhausted
/// - Underlying transport is disconnected
/// - Client has been explicitly closed
fn is_healthy(&self) -> bool;
/// Get client statistics (optional)
///
/// Returns runtime statistics about the client for monitoring and debugging.
/// The default implementation returns `ClientStats::default()`, i.e. all
/// counters zero.
fn stats(&self) -> ClientStats {
ClientStats::default()
}
/// Close the client gracefully (optional)
///
/// Implementations should:
/// - Close all active connections
/// - Wait for in-flight requests to complete (or timeout)
/// - Release all resources
///
/// The default implementation does nothing and returns `Ok(())`.
async fn close(&self) -> Result<()> {
Ok(())
}
}
/// Runtime counters reported by a request plane client.
///
/// Intended for monitoring and debugging; every field defaults to zero.
#[derive(Debug, Clone, Default)]
pub struct ClientStats {
    /// Total number of requests sent
    pub requests_sent: u64,
    /// Total number of successful responses
    pub responses_received: u64,
    /// Total number of errors
    pub errors: u64,
    /// Total bytes sent
    pub bytes_sent: u64,
    /// Total bytes received
    pub bytes_received: u64,
    /// Number of active connections (for connection-pooled transports)
    pub active_connections: usize,
    /// Number of idle connections in pool
    pub idle_connections: usize,
    /// Average request latency in microseconds (0 if not available)
    pub avg_latency_us: u64,
}

impl ClientStats {
    /// Returns statistics with every counter set to zero.
    pub fn new() -> Self {
        ClientStats::default()
    }

    /// Tells whether these statistics carry any information.
    ///
    /// Only `requests_sent` and `active_connections` are consulted: the
    /// result is `true` as soon as either of them is non-zero.
    pub fn is_available(&self) -> bool {
        let has_traffic = self.requests_sent != 0;
        let has_connections = self.active_connections != 0;
        has_traffic || has_connections
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default statistics are all-zero and therefore not "available".
    #[test]
    fn test_client_stats_default() {
        let fresh = ClientStats::default();
        assert_eq!(
            (fresh.requests_sent, fresh.responses_received),
            (0, 0)
        );
        assert!(!fresh.is_available());
    }

    /// Either a sent request or an active connection makes statistics available.
    #[test]
    fn test_client_stats_is_available() {
        let mut by_requests = ClientStats::default();
        assert!(!by_requests.is_available());
        by_requests.requests_sent = 1;
        assert!(by_requests.is_available());

        let by_connections = ClientStats {
            active_connections: 1,
            ..Default::default()
        };
        assert!(by_connections.is_available());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod http_endpoint;
pub mod nats_server;
pub mod push_endpoint;
pub mod push_handler;
pub mod shared_tcp_endpoint;
pub mod unified_server;
use super::*;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! HTTP endpoint for receiving requests via Axum/HTTP/2
use super::*;
use crate::SystemHealth;
use crate::config::HealthStatus;
use crate::logging::TraceParent;
use anyhow::Result;
use axum::{
Router,
body::Bytes,
extract::{Path, State as AxumState},
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::post,
};
use dashmap::DashMap;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as Http2Builder;
use hyper_util::service::TowerToHyperService;
use parking_lot::Mutex;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
use tracing::Instrument;
/// Default root path for dynamo RPC endpoints, used when the
/// `DYN_HTTP_RPC_ROOT_PATH` environment variable is not set.
const DEFAULT_RPC_ROOT_PATH: &str = "/v1/rpc";
/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Shared HTTP server that handles multiple endpoints on a single port
pub struct SharedHttpServer {
// Registered endpoint handlers, keyed by the subject used for request routing.
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
// Socket address the listener binds to in `start`.
bind_addr: SocketAddr,
// Cancelling this token stops the accept loop and open connections.
cancellation_token: CancellationToken,
}
/// Handler for a specific endpoint
struct EndpointHandler {
// Work handler that receives each request payload.
service_handler: Arc<dyn PushWorkHandler>,
// Identity fields below are attached to the tracing span of each request.
instance_id: u64,
namespace: Arc<String>,
component_name: Arc<String>,
endpoint_name: Arc<String>,
// Shared health registry; updated to NotReady when the endpoint is unregistered.
system_health: Arc<Mutex<SystemHealth>>,
// Number of requests currently being processed for this endpoint.
inflight: Arc<AtomicU64>,
// Signalled each time a request finishes.
// NOTE(review): no waiter on this Notify is visible in this file — confirm it is still needed.
notify: Arc<Notify>,
}
impl SharedHttpServer {
    /// Creates a server that will listen on `bind_addr` once [`start`](Self::start)
    /// is called and that shuts down when `cancellation_token` is cancelled.
    ///
    /// No socket is opened here; the handler table starts out empty.
    pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
        Arc::new(Self {
            handlers: Arc::new(DashMap::new()),
            bind_addr,
            cancellation_token,
        })
    }

    /// Register an endpoint handler with this server
    ///
    /// `subject` is the routing key matched against the request path;
    /// `endpoint_name` is the key under which the endpoint is marked `Ready`
    /// in `system_health`. An existing handler registered under the same
    /// `subject` is replaced.
    ///
    /// Currently always returns `Ok(())`.
    // The argument list mirrors the endpoint identity fields one-to-one.
    #[allow(clippy::too_many_arguments)]
    pub async fn register_endpoint(
        &self,
        subject: String,
        service_handler: Arc<dyn PushWorkHandler>,
        instance_id: u64,
        namespace: String,
        component_name: String,
        endpoint_name: String,
        system_health: Arc<Mutex<SystemHealth>>,
    ) -> Result<()> {
        let handler = Arc::new(EndpointHandler {
            service_handler,
            instance_id,
            namespace: Arc::new(namespace),
            component_name: Arc::new(component_name),
            endpoint_name: Arc::new(endpoint_name.clone()),
            system_health: system_health.clone(),
            inflight: Arc::new(AtomicU64::new(0)),
            notify: Arc::new(Notify::new()),
        });
        // Set health status
        system_health
            .lock()
            .set_endpoint_health_status(&endpoint_name, HealthStatus::Ready);
        let subject_clone = subject.clone();
        self.handlers.insert(subject, handler);
        tracing::debug!("Registered endpoint handler for subject: {}", subject_clone);
        Ok(())
    }

    /// Unregister an endpoint handler
    ///
    /// Removes the handler stored under `subject` and marks `endpoint_name`
    /// as `NotReady`. Does nothing when no handler is registered for `subject`.
    pub async fn unregister_endpoint(&self, subject: &str, endpoint_name: &str) {
        if let Some((_, handler)) = self.handlers.remove(subject) {
            handler
                .system_health
                .lock()
                .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
            tracing::debug!("Unregistered endpoint handler for subject: {}", subject);
        }
    }

    /// Start the shared HTTP server
    ///
    /// Binds `bind_addr`, then accepts connections until the cancellation
    /// token fires. Requests are routed by POST on
    /// `{root}/{*endpoint}`, where `root` comes from `DYN_HTTP_RPC_ROOT_PATH`
    /// or the built-in default.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound. Accept failures are
    /// logged and the loop continues. Returns `Ok(())` on cancellation.
    pub async fn start(self: Arc<Self>) -> Result<()> {
        let rpc_root_path = std::env::var("DYN_HTTP_RPC_ROOT_PATH")
            .unwrap_or_else(|_| DEFAULT_RPC_ROOT_PATH.to_string());
        let route_pattern = format!("{}/{{*endpoint}}", rpc_root_path);
        let app = Router::new()
            .route(&route_pattern, post(handle_shared_request))
            .layer(TraceLayer::new_for_http())
            .with_state(self.clone());
        tracing::info!(
            "Starting shared HTTP/2 endpoint server on {} at path {}/:endpoint",
            self.bind_addr,
            rpc_root_path
        );
        let listener = tokio::net::TcpListener::bind(&self.bind_addr).await?;
        let cancellation_token = self.cancellation_token.clone();
        loop {
            tokio::select! {
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((stream, _addr)) => {
                            let app_clone = app.clone();
                            let cancel_clone = cancellation_token.clone();
                            // Each connection is served on its own task so a slow
                            // peer cannot stall the accept loop.
                            tokio::spawn(async move {
                                // Connection builder from hyper-util's `auto` module
                                let http2_builder = Http2Builder::new(TokioExecutor::new());
                                let io = TokioIo::new(stream);
                                let tower_service = app_clone.into_service();
                                // Wrap Tower service for Hyper compatibility
                                let hyper_service = TowerToHyperService::new(tower_service);
                                tokio::select! {
                                    result = http2_builder.serve_connection(io, hyper_service) => {
                                        if let Err(e) = result {
                                            tracing::debug!("HTTP/2 connection error: {}", e);
                                        }
                                    }
                                    _ = cancel_clone.cancelled() => {
                                        tracing::trace!("Connection cancelled");
                                    }
                                }
                            });
                        }
                        Err(e) => {
                            tracing::error!("Failed to accept connection: {}", e);
                        }
                    }
                }
                _ = cancellation_token.cancelled() => {
                    tracing::info!("SharedHttpServer received cancellation signal, shutting down");
                    return Ok(());
                }
            }
        }
    }

    /// Wait for all inflight requests across all endpoints
    ///
    /// Polls every 100 ms until the inflight counter of each endpoint that
    /// was registered when this call started has dropped to zero.
    pub async fn wait_for_inflight(&self) {
        // Snapshot the counters first. Iterating the DashMap while awaiting would
        // keep a shard lock held across the sleeps and block concurrent
        // register/unregister calls touching the same shard.
        let counters: Vec<Arc<AtomicU64>> = self
            .handlers
            .iter()
            .map(|entry| Arc::clone(&entry.value().inflight))
            .collect();
        for counter in counters {
            while counter.load(Ordering::SeqCst) > 0 {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            }
        }
    }
}
/// HTTP handler for the shared server
async fn handle_shared_request(
AxumState(server): AxumState<Arc<SharedHttpServer>>,
Path(endpoint_path): Path<String>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
// Look up the handler for this endpoint (lock-free read with DashMap)
let handler = match server.handlers.get(&endpoint_path) {
Some(h) => h.clone(),
None => {
tracing::warn!("No handler found for endpoint: {}", endpoint_path);
return (StatusCode::NOT_FOUND, "Endpoint not found");
}
};
// Increment inflight counter
handler.inflight.fetch_add(1, Ordering::SeqCst);
// Extract tracing headers
let traceparent = TraceParent::from_axum_headers(&headers);
// Spawn async handler
let service_handler = handler.service_handler.clone();
let inflight = handler.inflight.clone();
let notify = handler.notify.clone();
let namespace = handler.namespace.clone();
let component_name = handler.component_name.clone();
let endpoint_name = handler.endpoint_name.clone();
let instance_id = handler.instance_id;
tokio::spawn(async move {
tracing::trace!(instance_id, "handling new HTTP request");
let result = service_handler
.handle_payload(body)
.instrument(tracing::info_span!(
"handle_payload",
component = component_name.as_ref(),
endpoint = endpoint_name.as_ref(),
namespace = namespace.as_ref(),
instance_id = instance_id,
trace_id = traceparent.trace_id,
parent_id = traceparent.parent_id,
x_request_id = traceparent.x_request_id,
x_dynamo_request_id = traceparent.x_dynamo_request_id,
tracestate = traceparent.tracestate
))
.await;
match result {
Ok(_) => {
tracing::trace!(instance_id, "request handled successfully");
}
Err(e) => {
tracing::warn!("Failed to handle request: {}", e.to_string());
}
}
// Decrease inflight counter
inflight.fetch_sub(1, Ordering::SeqCst);
notify.notify_one();
});
// Return 202 Accepted immediately (like NATS ack)
(StatusCode::ACCEPTED, "")
}
/// Axum header support for `TraceParent`
impl TraceParent {
    /// Builds a `TraceParent` from the tracing headers of an HTTP request.
    ///
    /// Reads `traceparent`, `tracestate`, `x-request-id` and
    /// `x-dynamo-request-id`. Headers that are absent or not valid visible
    /// ASCII are skipped and leave the corresponding field at its default.
    ///
    /// A `traceparent` value in W3C Trace Context form
    /// (`version-traceid-parentid-flags`) is split so that `trace_id` and
    /// `parent_id` receive their respective components. A value that does not
    /// match that form is stored verbatim in `trace_id`, with `parent_id`
    /// left unset.
    pub fn from_axum_headers(headers: &HeaderMap) -> Self {
        /// Returns the header value as text, if present and representable as `&str`.
        fn header_text<'h>(headers: &'h HeaderMap, name: &'static str) -> Option<&'h str> {
            headers.get(name).and_then(|value| value.to_str().ok())
        }

        /// Extracts `(trace_id, parent_id)` from a W3C `traceparent` value.
        fn split_w3c_traceparent(raw: &str) -> Option<(&str, &str)> {
            let is_hex = |field: &str, len: usize| {
                field.len() == len && field.bytes().all(|b| b.is_ascii_hexdigit())
            };
            // All-zero identifiers are invalid per the specification.
            let is_non_zero = |field: &str| field.bytes().any(|b| b != b'0');

            // Trailing fields are tolerated for future versions of the format.
            let mut fields = raw.trim().split('-');
            let version = fields.next()?;
            let trace_id = fields.next()?;
            let parent_id = fields.next()?;
            let flags = fields.next()?;

            let well_formed = is_hex(version, 2)
                && is_hex(trace_id, 32)
                && is_non_zero(trace_id)
                && is_hex(parent_id, 16)
                && is_non_zero(parent_id)
                && is_hex(flags, 2);
            well_formed.then_some((trace_id, parent_id))
        }

        let mut traceparent = TraceParent::default();
        if let Some(raw) = header_text(headers, "traceparent") {
            match split_w3c_traceparent(raw) {
                Some((trace_id, parent_id)) => {
                    traceparent.trace_id = Some(trace_id.to_string());
                    traceparent.parent_id = Some(parent_id.to_string());
                }
                // Not W3C-shaped: keep the raw value rather than dropping it.
                None => traceparent.trace_id = Some(raw.to_string()),
            }
        }
        if let Some(state) = header_text(headers, "tracestate") {
            traceparent.tracestate = Some(state.to_string());
        }
        if let Some(request_id) = header_text(headers, "x-request-id") {
            traceparent.x_request_id = Some(request_id.to_string());
        }
        if let Some(dynamo_request_id) = header_text(headers, "x-dynamo-request-id") {
            traceparent.x_dynamo_request_id = Some(dynamo_request_id.to_string());
        }
        traceparent
    }
}
// RequestPlaneServer implementation for SharedHttpServer; delegates to the
// inherent methods of the same names.
#[async_trait::async_trait]
impl super::unified_server::RequestPlaneServer for SharedHttpServer {
    /// Registers `service_handler` under `endpoint_name`.
    ///
    /// For HTTP, `endpoint_name` serves both as the routing subject and as the
    /// name used for health reporting.
    async fn register_endpoint(
        &self,
        endpoint_name: String,
        service_handler: Arc<dyn PushWorkHandler>,
        instance_id: u64,
        namespace: String,
        component_name: String,
        system_health: Arc<Mutex<SystemHealth>>,
    ) -> Result<()> {
        // The seven-argument call resolves to the inherent method, not this one.
        self.register_endpoint(
            endpoint_name.clone(),
            service_handler,
            instance_id,
            namespace,
            component_name,
            endpoint_name,
            system_health,
        )
        .await
    }

    /// Removes the handler registered under `endpoint_name`; always returns `Ok(())`.
    async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
        self.unregister_endpoint(endpoint_name, endpoint_name).await;
        Ok(())
    }

    /// Returns the server's base URL, `http://<bind address>`.
    ///
    /// Uses the `SocketAddr` display form so that IPv6 hosts are wrapped in
    /// brackets (`http://[::1]:8080`); IPv4 output is `http://ip:port`.
    fn address(&self) -> String {
        format!("http://{}", self.bind_addr)
    }

    /// Returns the transport identifier, `"http"`.
    fn transport_name(&self) -> &'static str {
        "http"
    }

    /// Reports the server as healthy unconditionally.
    fn is_healthy(&self) -> bool {
        // Server is healthy if it has been created
        // TODO: Add more sophisticated health checks (e.g., check if listener is active)
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported tracing header ends up in its matching field.
    #[test]
    fn test_traceparent_from_axum_headers() {
        let pairs = [
            ("traceparent", "test-trace-id"),
            ("tracestate", "test-state"),
            ("x-request-id", "req-123"),
            ("x-dynamo-request-id", "dyn-456"),
        ];
        let mut headers = HeaderMap::new();
        for (name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }

        let parsed = TraceParent::from_axum_headers(&headers);

        assert_eq!(parsed.trace_id.as_deref(), Some("test-trace-id"));
        assert_eq!(parsed.tracestate.as_deref(), Some("test-state"));
        assert_eq!(parsed.x_request_id.as_deref(), Some("req-123"));
        assert_eq!(parsed.x_dynamo_request_id.as_deref(), Some("dyn-456"));
    }

    /// A new server starts with no registered handlers.
    #[test]
    fn test_shared_http_server_creation() {
        use std::net::{IpAddr, Ipv4Addr};

        let loopback = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let server = SharedHttpServer::new(SocketAddr::new(loopback, 0), CancellationToken::new());
        assert!(server.handlers.is_empty());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NATS Multiplexed Server
//!
//! Provides a multiplexed NATS server that handles multiple endpoints on a single
//! NATS service group. This replaces the per-endpoint PushEndpoint pattern with
//! a unified multiplexed approach consistent with HTTP and TCP servers.
use super::*;
use crate::SystemHealth;
use crate::config::HealthStatus;
use crate::pipeline::network::ingress::push_endpoint::PushEndpoint;
use anyhow::Result;
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::Mutex;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/// Multiplexed NATS server that handles multiple endpoints
///
/// Unlike the previous per-endpoint approach, this server manages multiple
/// endpoints, getting the service group dynamically from the component registry
/// for each endpoint registration.
pub struct NatsMultiplexedServer {
// NOTE(review): not read by any code visible in this file — confirm it is still needed.
nats_client: async_nats::Client,
// Source of the NATS service for each `{namespace}_{component}` pair.
component_registry: crate::component::Registry,
// Per-endpoint cancellation state, keyed by endpoint name.
handlers: Arc<DashMap<String, EndpointTask>>,
// NOTE(review): stored but not linked to the per-endpoint tokens in the visible code.
cancellation_token: CancellationToken,
}
/// Bookkeeping for one registered NATS endpoint.
struct EndpointTask {
// Cancelled on unregistration to stop the endpoint's listener task.
cancel_token: CancellationToken,
// Kept for reference only; the leading underscore marks it as unread.
_endpoint_name: String,
}
impl NatsMultiplexedServer {
    /// Builds a multiplexed NATS server with an empty endpoint table.
    ///
    /// # Arguments
    ///
    /// * `nats_client` - NATS client for connection management
    /// * `component_registry` - Component registry to get service groups from
    /// * `cancellation_token` - Token for graceful shutdown
    pub fn new(
        nats_client: async_nats::Client,
        component_registry: crate::component::Registry,
        cancellation_token: CancellationToken,
    ) -> Arc<Self> {
        let handlers = Arc::new(DashMap::new());
        let server = Self {
            nats_client,
            component_registry,
            handlers,
            cancellation_token,
        };
        Arc::new(server)
    }
}
#[async_trait]
impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
/// Registers an endpoint on the NATS service of `{namespace}_{component_name}`
/// and starts a background listener task for it.
///
/// Steps: look up the service in the component registry, create a NATS
/// endpoint named `{endpoint_name}-{instance_id:x}`, build a `PushEndpoint`
/// with its own cancellation token, spawn its `start` loop, wait 10 ms, then
/// record the token under `endpoint_name` for later unregistration.
///
/// # Errors
///
/// Returns an error if the service is not in the registry, if the NATS
/// endpoint cannot be created, or if the push endpoint cannot be built.
/// Failures inside the spawned listener task are only logged.
///
/// NOTE(review): a second registration under the same `endpoint_name`
/// replaces the stored token without cancelling the earlier one — confirm
/// duplicates cannot occur or cancel the previous entry.
async fn register_endpoint(
&self,
endpoint_name: String,
service_handler: Arc<dyn PushWorkHandler>,
instance_id: u64,
namespace: String,
component_name: String,
system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()> {
tracing::info!(
endpoint_name = %endpoint_name,
namespace = %namespace,
component = %component_name,
instance_id = instance_id,
"NatsMultiplexedServer::register_endpoint called"
);
// Get the service group from the component registry
// Service name format matches Component::service_name(): "{namespace}_{component}" slugified
use crate::transports::nats::Slug;
let service_name_raw = format!("{}_{}", namespace, component_name);
let service_name = Slug::slugify(&service_name_raw).to_string();
tracing::debug!(
service_name_raw = %service_name_raw,
service_name = %service_name,
"Looking up service group in registry"
);
let registry = self.component_registry.inner.lock().await;
let service_group = registry
.services
.get(&service_name)
.map(|service| service.group(&service_name))
.ok_or_else(|| anyhow::anyhow!("Service '{}' not found in registry", service_name))?;
// Release the registry lock before the network call below.
drop(registry);
tracing::info!("Successfully retrieved service group");
// Construct the full NATS subject with instance ID
// Format: {endpoint_name}-{instance_id_hex}
// This matches Endpoint::name_with_id() and subject_to() format
let endpoint_with_id = format!("{}-{:x}", endpoint_name, instance_id);
// Create NATS service endpoint with the full subject
let service_endpoint = service_group
.endpoint(&endpoint_with_id)
.await
.map_err(|e| {
anyhow::anyhow!(
"Failed to create NATS endpoint '{}': {}",
endpoint_with_id,
e
)
})?;
tracing::info!(
endpoint_name = %endpoint_name,
endpoint_with_id = %endpoint_with_id,
namespace = %namespace,
component = %component_name,
instance_id = instance_id,
"Registering NATS endpoint"
);
// Create cancellation token for this specific endpoint
// NOTE(review): this is a fresh token, not a child of the server-level
// token — confirm server shutdown is expected to go through unregister_endpoint.
let endpoint_cancel = CancellationToken::new();
let endpoint_cancel_clone = endpoint_cancel.clone();
// Build the push endpoint
let push_endpoint = PushEndpoint::builder()
.service_handler(service_handler)
.cancellation_token(endpoint_cancel_clone)
.graceful_shutdown(true)
.build()
.map_err(|e| anyhow::anyhow!("Failed to build NATS push endpoint: {}", e))?;
tracing::info!(
endpoint_name = %endpoint_name,
endpoint_with_id = %endpoint_with_id,
"Starting NATS push endpoint listener (blocking)"
);
// Spawn task to handle this endpoint using PushEndpoint
// Note: PushEndpoint::start() is a long-running loop that runs until cancelled,
// which is why it gets its own task instead of being awaited here
let endpoint_name_clone = endpoint_name.clone();
tokio::spawn(async move {
if let Err(e) = push_endpoint
.start(
service_endpoint,
namespace,
component_name,
endpoint_name_clone.clone(),
instance_id,
system_health,
)
.await
{
tracing::error!(
endpoint_name = %endpoint_name_clone,
error = %e,
"NATS endpoint task failed"
);
} else {
tracing::info!(
endpoint_name = %endpoint_name_clone,
"NATS push endpoint listener completed"
);
}
});
// Give the endpoint a moment to start listening
// This is meant to narrow a race where discovery registers the endpoint
// before NATS is actually ready to receive requests; the fixed 10 ms delay
// is a heuristic, not a readiness guarantee
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Store task info for later cleanup
self.handlers.insert(
endpoint_name.clone(),
EndpointTask {
cancel_token: endpoint_cancel,
_endpoint_name: endpoint_name,
},
);
Ok(())
}
/// Cancels the listener task registered under `endpoint_name`, if any.
///
/// Only signals cancellation; it does not wait for the task to finish.
/// Always returns `Ok(())`, including when the name is unknown.
async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
if let Some((_, task)) = self.handlers.remove(endpoint_name) {
tracing::info!(
endpoint_name = %endpoint_name,
"Unregistering NATS endpoint"
);
task.cancel_token.cancel();
}
Ok(())
}
/// Returns the fixed placeholder `"nats://connected"` rather than a real server URL.
fn address(&self) -> String {
// Return NATS server URL from connection info
// NATS client doesn't expose server info directly, return generic address
"nats://connected".to_string()
}
/// Returns the transport identifier, `"nats"`.
fn transport_name(&self) -> &'static str {
"nats"
}
/// Reports the server as healthy unconditionally; the client connection is not inspected.
fn is_healthy(&self) -> bool {
// Check if NATS client is connected
// NATS client doesn't expose connection state directly, assume healthy
true
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared TCP Server with Endpoint Multiplexing
//!
//! Provides a shared TCP server that can handle multiple endpoints on a single port
//! by adding endpoint routing to the TCP wire protocol.
use crate::SystemHealth;
use crate::pipeline::network::PushWorkHandler;
use anyhow::Result;
use bytes::Bytes;
use dashmap::DashMap;
use parking_lot::Mutex;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Notify;
use tokio_util::bytes::BytesMut;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
/// Default maximum message size for TCP server (32 MB)
const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Resolves the maximum accepted message size in bytes.
///
/// Reads `DYN_TCP_MAX_MESSAGE_SIZE`; when the variable is unset, not valid
/// Unicode, or not parseable as `usize`, falls back to
/// `DEFAULT_MAX_MESSAGE_SIZE`.
fn get_max_message_size() -> usize {
    match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
        Err(_) => DEFAULT_MAX_MESSAGE_SIZE,
    }
}
/// Shared TCP server that handles multiple endpoints on a single port
pub struct SharedTcpServer {
// Registered endpoint handlers, keyed by endpoint path.
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
// Socket address the listener binds to in `start`.
bind_addr: SocketAddr,
// Cancelling this token stops the accept loop.
cancellation_token: CancellationToken,
}
/// Everything the shared TCP server keeps per registered endpoint.
struct EndpointHandler {
// Work handler for requests addressed to this endpoint.
service_handler: Arc<dyn PushWorkHandler>,
// Identity of the endpoint instance.
instance_id: u64,
namespace: String,
component_name: String,
endpoint_name: String,
// Shared health registry handle.
// NOTE(review): not updated by register/unregister in the visible code, unlike the HTTP server — confirm intended.
system_health: Arc<Mutex<SystemHealth>>,
// Inflight request counter, initialised to zero at registration.
inflight: Arc<AtomicU64>,
// Notification primitive created at registration; its use is outside the visible code.
notify: Arc<Notify>,
}
impl SharedTcpServer {
/// Creates a shared TCP server for `bind_addr` with no endpoints registered.
///
/// The socket is not opened here; `cancellation_token` is kept so the accept
/// loop can be stopped later.
pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
    let handlers = Arc::new(DashMap::new());
    let server = Self {
        handlers,
        bind_addr,
        cancellation_token,
    };
    Arc::new(server)
}
/// Registers `service_handler` under `endpoint_path`.
///
/// A handler already stored under the same path is replaced. The inflight
/// counter of the new registration starts at zero. Always returns `Ok(())`.
#[allow(clippy::too_many_arguments)]
pub async fn register_endpoint(
    &self,
    endpoint_path: String,
    service_handler: Arc<dyn PushWorkHandler>,
    instance_id: u64,
    namespace: String,
    component_name: String,
    endpoint_name: String,
    system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()> {
    let inflight = Arc::new(AtomicU64::new(0));
    let notify = Arc::new(Notify::new());
    let registration = EndpointHandler {
        service_handler,
        instance_id,
        namespace,
        component_name,
        // The name is cloned because it is still needed for the log line below.
        endpoint_name: endpoint_name.clone(),
        system_health,
        inflight,
        notify,
    };
    self.handlers.insert(endpoint_path, Arc::new(registration));
    tracing::info!(
        "Registered endpoint '{}' with shared TCP server on {}",
        endpoint_name,
        self.bind_addr
    );
    Ok(())
}
/// Removes the handler registered under `endpoint_path`, if there is one.
///
/// `endpoint_name` is used only in the log message. Requests already dispatched to the
/// removed handler are not waited for here.
pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) {
    drop(self.handlers.remove(endpoint_path));
    tracing::info!(
        "Unregistered endpoint '{}' from shared TCP server",
        endpoint_name
    );
}
/// Binds the listener and runs the accept loop until the cancellation token fires.
///
/// Each accepted connection is served on its own spawned task; connection-level errors
/// are logged at debug level and do not stop the server.
///
/// A failed `accept` is logged and followed by a short pause before retrying. Without the
/// pause, a persistent failure (for example, running out of file descriptors) would make
/// this loop spin and flood the log.
///
/// # Errors
///
/// Returns an error if binding `bind_addr` fails. Returns `Ok(())` on cancellation.
pub async fn start(self: Arc<Self>) -> Result<()> {
    /// Pause between a failed `accept` and the next attempt.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    tracing::info!("Starting shared TCP server on {}", self.bind_addr);
    let listener = TcpListener::bind(&self.bind_addr).await?;
    let cancellation_token = self.cancellation_token.clone();

    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer_addr)) => {
                        tracing::trace!("Accepted TCP connection from {}", peer_addr);
                        let handlers = self.handlers.clone();
                        tokio::spawn(async move {
                            if let Err(e) = Self::handle_connection(stream, handlers).await {
                                tracing::debug!("TCP connection error: {}", e);
                            }
                        });
                    }
                    Err(e) => {
                        tracing::error!("Failed to accept TCP connection: {}", e);
                        // Back off before retrying, but stay responsive to shutdown.
                        tokio::select! {
                            _ = tokio::time::sleep(ACCEPT_ERROR_BACKOFF) => {}
                            _ = cancellation_token.cancelled() => {
                                tracing::info!("SharedTcpServer received cancellation signal, shutting down");
                                return Ok(());
                            }
                        }
                    }
                }
            }
            _ = cancellation_token.cancelled() => {
                tracing::info!("SharedTcpServer received cancellation signal, shutting down");
                return Ok(());
            }
        }
    }
}
/// Serves a single TCP connection until the peer disconnects or an I/O error occurs.
///
/// The stream is split so that reading requests and writing responses proceed
/// independently: `read_loop` runs on the current task and queues encoded responses on an
/// unbounded channel, while `write_loop` runs on a spawned task and drains that channel.
///
/// # Errors
///
/// Returns the write task's failure (panic/cancellation or I/O error) if there is one,
/// otherwise the result of the read loop.
async fn handle_connection(
    stream: TcpStream,
    handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
) -> Result<()> {
    // Split stream into read and write halves for concurrent operations
    let (read_half, write_half) = tokio::io::split(stream);

    // Channel carrying already-encoded responses to the write task
    let (response_tx, response_rx) = tokio::sync::mpsc::unbounded_channel::<Bytes>();

    let write_task = tokio::spawn(Self::write_loop(write_half, response_rx));

    // The read loop owns the only sender; when it returns, the sender is dropped and the
    // write task finishes after draining what is still queued.
    let read_result = Self::read_loop(read_half, handlers, response_tx).await;

    write_task.await??;
    read_result
}
async fn read_loop(
mut read_half: tokio::io::ReadHalf<TcpStream>,
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
response_tx: tokio::sync::mpsc::UnboundedSender<Bytes>,
) -> Result<()> {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage};
loop {
// Read endpoint path length (2 bytes)
let mut path_len_buf = [0u8; 2];
match read_half.read_exact(&mut path_len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
break;
}
Err(e) => {
return Err(e.into());
}
}
let path_len = u16::from_be_bytes(path_len_buf) as usize;
// Read endpoint path
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await?;
// Read payload length (4 bytes)
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await?;
let payload_len = u32::from_be_bytes(len_buf) as usize;
// Sanity check - enforce maximum message size
let max_message_size = get_max_message_size();
if payload_len > max_message_size {
tracing::warn!(
"Request too large: {} bytes (max: {} bytes), closing connection",
payload_len,
max_message_size
);
// Send error response
let error_response =
TcpResponseMessage::new(Bytes::from_static(b"Request too large"));
if let Ok(encoded) = error_response.encode() {
let _ = response_tx.send(encoded);
}
break;
}
// Read request payload
let mut payload_buf = vec![0u8; payload_len];
read_half.read_exact(&mut payload_buf).await?;
// Reconstruct the full message buffer for decoding using BytesMut
let mut full_msg = BytesMut::with_capacity(2 + path_len + 4 + payload_len);
full_msg.extend_from_slice(&path_len_buf);
full_msg.extend_from_slice(&path_buf);
full_msg.extend_from_slice(&len_buf);
full_msg.extend_from_slice(&payload_buf);
// Decode using codec (zero-copy conversion)
let full_msg_bytes = full_msg.freeze();
let request_msg = match TcpRequestMessage::decode(&full_msg_bytes) {
Ok(msg) => msg,
Err(e) => {
tracing::warn!("Failed to decode TCP request: {}", e);
// Send error response
let error_response =
TcpResponseMessage::new(Bytes::from(format!("Decode error: {}", e)));
if let Ok(encoded) = error_response.encode() {
let _ = response_tx.send(encoded);
}
continue;
}
};
let endpoint_path = request_msg.endpoint_path;
let payload = request_msg.payload;
// Look up handler (lock-free read with DashMap)
let handler = handlers.get(&endpoint_path).map(|h| h.clone());
let handler = match handler {
Some(h) => h,
None => {
tracing::warn!("No handler found for endpoint: {}", endpoint_path);
// Send error response using codec
let error_response = TcpResponseMessage::new(Bytes::from(format!(
"Unknown endpoint: {}",
endpoint_path
)));
if let Ok(encoded) = error_response.encode() {
let _ = response_tx.send(encoded);
}
continue;
}
};
handler.inflight.fetch_add(1, Ordering::SeqCst);
// Send acknowledgment immediately using codec (non-blocking, zero-copy)
let ack_response = TcpResponseMessage::empty();
if let Ok(encoded_ack) = ack_response.encode() {
// Send to write task without blocking reads
if response_tx.send(encoded_ack).is_err() {
tracing::debug!("Write task closed, ending read loop");
break;
}
}
// Process request asynchronously
let service_handler = handler.service_handler.clone();
let inflight = handler.inflight.clone();
let notify = handler.notify.clone();
let instance_id = handler.instance_id;
let namespace = handler.namespace.clone();
let component_name = handler.component_name.clone();
let endpoint_name = handler.endpoint_name.clone();
tokio::spawn(async move {
tracing::trace!(instance_id, "handling TCP request");
let result = service_handler
.handle_payload(payload)
.instrument(tracing::info_span!(
"handle_payload",
component = component_name.as_str(),
endpoint = endpoint_name.as_str(),
namespace = namespace.as_str(),
instance_id = instance_id,
))
.await;
match result {
Ok(_) => {
tracing::trace!(instance_id, "TCP request handled successfully");
}
Err(e) => {
tracing::warn!("Failed to handle TCP request: {}", e);
}
}
inflight.fetch_sub(1, Ordering::SeqCst);
notify.notify_one();
});
}
Ok(())
}
/// Writes queued responses to the socket, one at a time, flushing after each.
///
/// Returns `Ok(())` once every sender of the channel has been dropped and the queue is
/// empty, or the first write/flush error encountered.
async fn write_loop(
    mut write_half: tokio::io::WriteHalf<TcpStream>,
    mut response_rx: tokio::sync::mpsc::UnboundedReceiver<Bytes>,
) -> Result<()> {
    loop {
        match response_rx.recv().await {
            Some(frame) => {
                write_half.write_all(&frame).await?;
                write_half.flush().await?;
            }
            None => return Ok(()),
        }
    }
}
}
// Implement RequestPlaneServer trait for SharedTcpServer
#[async_trait::async_trait]
impl super::unified_server::RequestPlaneServer for SharedTcpServer {
    /// Registers `service_handler` for `endpoint_name`.
    ///
    /// For TCP the endpoint name doubles as the routing key, so it is passed to the
    /// inherent `register_endpoint` both as the endpoint path and as the endpoint name.
    async fn register_endpoint(
        &self,
        endpoint_name: String,
        service_handler: Arc<dyn PushWorkHandler>,
        instance_id: u64,
        namespace: String,
        component_name: String,
        system_health: Arc<Mutex<SystemHealth>>,
    ) -> Result<()> {
        // Resolves to the inherent seven-argument method, not to this trait method.
        self.register_endpoint(
            endpoint_name.clone(),
            service_handler,
            instance_id,
            namespace,
            component_name,
            endpoint_name,
            system_health,
        )
        .await
    }

    /// Removes the handler registered for `endpoint_name`; always returns `Ok(())`.
    async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
        self.unregister_endpoint(endpoint_name, endpoint_name).await;
        Ok(())
    }

    /// Returns the bind address as `tcp://<addr>`.
    ///
    /// The address is rendered with `SocketAddr`'s `Display`, which yields `ip:port` for
    /// IPv4 and the bracketed `[ip]:port` form for IPv6, so the port stays unambiguous.
    fn address(&self) -> String {
        format!("tcp://{}", self.bind_addr)
    }

    /// Returns the transport identifier, `"tcp"`.
    fn transport_name(&self) -> &'static str {
        "tcp"
    }

    /// Reports the server as healthy unconditionally.
    fn is_healthy(&self) -> bool {
        // Server is healthy if it has been created
        // TODO: Add more sophisticated health checks (e.g., check if listener is active)
        true
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Unified Request Plane Server Interface
//!
//! This module defines a transport-agnostic interface for request plane servers.
//! All transport implementations (HTTP, TCP, NATS) implement this trait to provide
//! a consistent interface for endpoint registration and management.
use super::*;
use crate::SystemHealth;
use anyhow::Result;
use async_trait::async_trait;
use parking_lot::Mutex;
use std::sync::Arc;
/// Unified interface for request plane servers
///
/// This trait abstracts over different transport mechanisms (HTTP/2, TCP, NATS)
/// providing a consistent interface for registering endpoints and managing server lifecycle.
///
/// # Design Principles
///
/// 1. **Transport Agnostic**: Implementations can be swapped without changing business logic
/// 2. **Multiplexed**: All servers handle multiple endpoints on a single port/connection
/// 3. **Async by Default**: All operations are async to support high concurrency
/// 4. **Health Monitoring**: Servers provide health status for monitoring
///
/// # Example
///
/// ```ignore
/// use dynamo_runtime::pipeline::network::ingress::RequestPlaneServer;
///
/// async fn register(server: &dyn RequestPlaneServer) -> Result<()> {
/// server.register_endpoint(
/// "generate".to_string(),
/// handler,
/// instance_id,
/// "dynamo".to_string(),
/// "backend".to_string(),
/// system_health,
/// ).await?;
/// Ok(())
/// }
/// ```
#[async_trait]
pub trait RequestPlaneServer: Send + Sync {
/// Register an endpoint handler with the server
///
/// # Arguments
///
/// * `endpoint_name` - Name/path for routing (e.g., "generate", "health")
/// * `service_handler` - Handler that processes incoming requests
/// * `instance_id` - Unique instance identifier for this endpoint
/// * `namespace` - Service namespace (e.g., "dynamo")
/// * `component_name` - Component name (e.g., "backend", "frontend")
/// * `system_health` - Health tracking for this endpoint
///
/// # Returns
///
/// Returns `Ok(())` if registration succeeds, or an error if:
/// - Endpoint name is already registered
/// - Server is not running or has been stopped
/// - Transport-specific errors occur
///
/// NOTE(review): duplicate-name rejection is not uniform across implementations. The
/// shared TCP server stores handlers with a plain map insert, so registering an existing
/// name there replaces the previous handler and still returns `Ok(())`.
async fn register_endpoint(
&self,
endpoint_name: String,
service_handler: Arc<dyn PushWorkHandler>,
instance_id: u64,
namespace: String,
component_name: String,
system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()>;
/// Unregister an endpoint from the server
///
/// # Arguments
///
/// * `endpoint_name` - Name of the endpoint to unregister
///
/// # Returns
///
/// Returns `Ok(())` if unregistration succeeds or endpoint doesn't exist.
/// Errors are only returned for transport-specific failures.
async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()>;
/// Get server bind address or identifier
///
/// Returns a transport-specific address string:
/// - HTTP: `"http://0.0.0.0:8888"`
/// - TCP: `"tcp://0.0.0.0:9999"`
/// - NATS: `"nats://localhost:4222"`
///
/// Used for logging, debugging, and service discovery.
fn address(&self) -> String;
/// Get the transport name
///
/// Returns a static string identifier for the transport type.
/// Used for logging and debugging.
///
/// # Examples
///
/// - `"http"` - HTTP/2 transport
/// - `"tcp"` - Raw TCP transport
/// - `"nats"` - NATS messaging
fn transport_name(&self) -> &'static str;
/// Check if server is healthy and ready to accept requests
///
/// Returns `true` if the server is operational and can handle requests.
/// This is a lightweight check that doesn't perform actual network I/O.
///
/// Implementations should return `false` if:
/// - Server has been explicitly stopped
/// - Underlying transport is disconnected
/// - Server encountered a fatal error
///
/// NOTE(review): the shared TCP server currently returns `true` unconditionally, so
/// callers cannot yet rely on this to detect a stopped or failed TCP listener.
fn is_healthy(&self) -> bool;
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Network Manager - Single Source of Truth for Network Configuration
//!
//! This module consolidates ALL network-related configuration and creation logic.
//! It is the ONLY place in the codebase that:
//! - Reads environment variables for network configuration
//! - Knows about transport-specific types (SharedHttpServer, TcpRequestClient, etc.)
//! - Performs mode selection based on RequestPlaneMode
//! - Creates servers and clients
//!
//! The rest of the codebase works exclusively with trait objects and never
//! directly accesses transport implementations or configuration.
use super::egress::unified_client::RequestPlaneClient;
use super::ingress::unified_server::RequestPlaneServer;
use crate::config::RequestPlaneMode;
use anyhow::Result;
use async_once_cell::OnceCell;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/// Network configuration loaded from environment variables
///
/// Built once by `NetworkConfig::from_env`; the values are a snapshot taken at that time.
#[derive(Clone)]
struct NetworkConfig {
// HTTP server configuration
/// Bind host for the HTTP server (`DYN_HTTP_RPC_HOST`, else a helper-provided default).
http_host: String,
/// Bind port for the HTTP server (`DYN_HTTP_RPC_PORT`, default 8888).
http_port: u16,
/// RPC root path (`DYN_HTTP_RPC_ROOT_PATH`, default `/v1/rpc`).
http_rpc_root: String,
// TCP server configuration
/// Bind host for the TCP server (`DYN_TCP_RPC_HOST`, else a helper-provided default).
tcp_host: String,
/// Bind port for the TCP server (`DYN_TCP_RPC_PORT`, default 9999).
tcp_port: u16,
// HTTP client configuration
/// Settings handed to each HTTP client that is created.
http_client_config: super::egress::http_router::Http2Config,
// TCP client configuration
/// Settings handed to each TCP client that is created.
tcp_client_config: super::egress::tcp_client::TcpRequestConfig,
// NATS configuration (provided externally, not from env)
/// NATS client supplied by the caller; `None` when NATS is not available.
nats_client: Option<async_nats::Client>,
}
impl NetworkConfig {
    /// Builds the configuration from environment variables plus the supplied NATS client.
    ///
    /// Host variables fall back to the crate's host helpers when unset. Port variables
    /// fall back to their defaults (8888 for HTTP, 9999 for TCP) when unset or not a valid
    /// `u16`. This is the ONLY place where network-related environment variables are read.
    fn from_env(nats_client: Option<async_nats::Client>) -> Self {
        // Shared lookup for the two port variables.
        let port_or = |var: &str, fallback: u16| -> u16 {
            std::env::var(var)
                .ok()
                .and_then(|raw| raw.parse().ok())
                .unwrap_or(fallback)
        };

        // HTTP server configuration
        let http_host = match std::env::var("DYN_HTTP_RPC_HOST") {
            Ok(host) => host,
            Err(_) => crate::utils::get_http_rpc_host_from_env(),
        };
        let http_port = port_or("DYN_HTTP_RPC_PORT", 8888);
        let http_rpc_root = match std::env::var("DYN_HTTP_RPC_ROOT_PATH") {
            Ok(root) => root,
            Err(_) => "/v1/rpc".to_string(),
        };

        // TCP server configuration
        let tcp_host = match std::env::var("DYN_TCP_RPC_HOST") {
            Ok(host) => host,
            Err(_) => crate::utils::get_tcp_rpc_host_from_env(),
        };
        let tcp_port = port_or("DYN_TCP_RPC_PORT", 9999);

        // HTTP client configuration (reads DYN_HTTP2_* env vars)
        let http_client_config = super::egress::http_router::Http2Config::from_env();
        // TCP client configuration (reads DYN_TCP_* env vars)
        let tcp_client_config = super::egress::tcp_client::TcpRequestConfig::from_env();

        Self {
            http_host,
            http_port,
            http_rpc_root,
            tcp_host,
            tcp_port,
            http_client_config,
            tcp_client_config,
            // NATS (external)
            nats_client,
        }
    }
}
/// Network Manager - Central coordinator for all network resources
///
/// # Responsibilities
///
/// 1. **Configuration Management**: Reads and manages all network-related environment variables
/// 2. **Server Creation**: Creates and starts request plane servers based on mode
/// 3. **Client Creation**: Creates request plane clients on demand
/// 4. **Abstraction**: Hides all transport-specific details from the rest of the codebase
///
/// # Design Principles
///
/// - **Single Source of Truth**: All network config and creation logic lives here
/// - **Lazy Initialization**: Servers are created only when first accessed
/// - **Transport Agnostic Interface**: Exposes only trait objects to callers
/// - **No Leaky Abstractions**: Transport types never escape this module
///
/// # Example
///
/// ```ignore
/// // Create manager (typically done once in DistributedRuntime)
/// let manager = NetworkManager::new(cancel_token, nats_client, component_registry);
///
/// // Get server (lazy init, cached)
/// let server = manager.server().await?;
/// server.register_endpoint(...).await?;
///
/// // Create client (not cached, lightweight)
/// let client = manager.create_client()?;
/// client.send_request(...).await?;
/// ```
pub struct NetworkManager {
/// Transport selected at construction; decides which server/client type is created.
mode: RequestPlaneMode,
/// Snapshot of the network configuration taken at construction.
config: NetworkConfig,
/// Lazily initialised server; filled on the first successful `server()` call.
server: Arc<OnceCell<Arc<dyn RequestPlaneServer>>>,
/// Token handed to every server this manager creates.
cancellation_token: CancellationToken,
/// Registry handed to the NATS server when NATS mode is used.
component_registry: crate::component::Registry,
}
impl NetworkManager {
/// Builds the network manager.
///
/// The request plane mode comes from `RequestPlaneMode::get()` and all other settings
/// from `NetworkConfig::from_env`; no server is created until `server()` is called.
///
/// # Arguments
///
/// * `cancellation_token` - Token passed on to the servers this manager creates
/// * `nats_client` - Optional NATS client (needed only in NATS mode)
/// * `component_registry` - Registry passed on to the NATS server
///
/// # Returns
///
/// The manager wrapped in an `Arc`.
pub fn new(
    cancellation_token: CancellationToken,
    nats_client: Option<async_nats::Client>,
    component_registry: crate::component::Registry,
) -> Arc<Self> {
    let mode = RequestPlaneMode::get();
    let config = NetworkConfig::from_env(nats_client);

    tracing::info!(
        mode = %mode,
        http_port = config.http_port,
        tcp_port = config.tcp_port,
        "Initializing NetworkManager"
    );

    let manager = Self {
        mode,
        config,
        server: Arc::new(OnceCell::new()),
        cancellation_token,
        component_registry,
    };
    Arc::new(manager)
}
/// Get or create the request plane server
///
/// The server is created lazily on first access and cached for subsequent calls.
/// The server is automatically started in the background.
///
/// # Returns
///
/// Returns a trait object that abstracts over HTTP/TCP/NATS implementations.
///
/// # Errors
///
/// Returns an error if:
/// - Server creation fails (e.g., port already in use)
/// - NATS mode is selected but NATS client is not available
/// - Configuration is invalid (e.g., malformed bind address)
pub async fn server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
let server = self
.server
.get_or_try_init(async { self.create_server().await })
.await?;
Ok(server.clone())
}
/// Creates a request plane client for the configured mode.
///
/// Nothing is cached: every call constructs a fresh client.
///
/// # Errors
///
/// Returns an error if:
/// - the transport-specific client constructor fails
/// - NATS mode is selected but no NATS client was supplied
pub fn create_client(&self) -> Result<Arc<dyn RequestPlaneClient>> {
    let client = match self.mode() {
        RequestPlaneMode::Http => self.create_http_client()?,
        RequestPlaneMode::Tcp => self.create_tcp_client()?,
        RequestPlaneMode::Nats => self.create_nats_client()?,
    };
    Ok(client)
}
/// Get the current request plane mode
///
/// Returns the mode captured when this manager was constructed.
///
/// This is provided primarily for logging and debugging purposes.
/// Application logic should not branch on mode - use trait objects instead.
pub fn mode(&self) -> RequestPlaneMode {
self.mode
}
// ============================================================================
// PRIVATE: Server Creation
// ============================================================================
/// Builds the server that matches the configured request plane mode.
async fn create_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
    let server = match self.mode() {
        RequestPlaneMode::Http => self.create_http_server().await?,
        RequestPlaneMode::Tcp => self.create_tcp_server().await?,
        RequestPlaneMode::Nats => self.create_nats_server().await?,
    };
    Ok(server)
}
/// Creates the shared HTTP server and starts it on a background task.
///
/// The bind address is `"{http_host}:{http_port}"`; an address that does not parse is
/// reported as an error. Errors returned by the server's `start` are logged by the
/// background task, not returned from here.
async fn create_http_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
    use super::ingress::http_endpoint::SharedHttpServer;

    let addr_text = format!("{}:{}", self.config.http_host, self.config.http_port);
    let bind_addr = addr_text
        .parse()
        .map_err(|e| anyhow::anyhow!("Invalid HTTP bind address: {}", e))?;

    tracing::info!(
        bind_addr = %bind_addr,
        rpc_root = %self.config.http_rpc_root,
        "Creating HTTP request plane server"
    );

    let server = SharedHttpServer::new(bind_addr, self.cancellation_token.clone());

    // Run the server in the background on its own handle
    tokio::spawn({
        let background = server.clone();
        async move {
            if let Err(e) = background.start().await {
                tracing::error!("HTTP request plane server error: {}", e);
            }
        }
    });

    Ok(server as Arc<dyn RequestPlaneServer>)
}
/// Creates the shared TCP server and starts it on a background task.
///
/// The bind address is built from `tcp_host` and `tcp_port`. A host that is a plain IP
/// literal (IPv4, or an un-bracketed IPv6 address such as `::`) is combined with the port
/// directly. Any other host falls back to parsing `"{host}:{port}"`, which keeps the
/// bracketed `[::1]` form working.
///
/// Errors returned by the server's `start` (including a failed bind) are logged by the
/// background task, not returned from here.
///
/// # Errors
///
/// Returns an error if the host/port pair does not form a valid socket address.
async fn create_tcp_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
    use super::ingress::shared_tcp_endpoint::SharedTcpServer;
    use std::net::{IpAddr, SocketAddr};

    let host = self.config.tcp_host.as_str();
    let port = self.config.tcp_port;

    // "::" + port cannot be parsed from the string ":::9999", so handle bare IPs first.
    let bind_addr: SocketAddr = match host.parse::<IpAddr>() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => format!("{}:{}", host, port)
            .parse()
            .map_err(|e| anyhow::anyhow!("Invalid TCP bind address: {}", e))?,
    };

    tracing::info!(
        bind_addr = %bind_addr,
        "Creating TCP request plane server"
    );

    let server = SharedTcpServer::new(bind_addr, self.cancellation_token.clone());

    // Start server in background
    let server_clone = server.clone();
    tokio::spawn(async move {
        if let Err(e) = server_clone.start().await {
            tracing::error!("TCP request plane server error: {}", e);
        }
    });

    Ok(server as Arc<dyn RequestPlaneServer>)
}
/// Creates the NATS multiplexed server from the configured NATS client.
///
/// # Errors
///
/// Returns an error if no NATS client was supplied to the manager.
async fn create_nats_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
    use super::ingress::nats_server::NatsMultiplexedServer;

    let Some(nats_client) = self.config.nats_client.as_ref() else {
        return Err(anyhow::anyhow!("NATS client required for NATS mode"));
    };

    tracing::info!("Creating NATS request plane server");

    let server = NatsMultiplexedServer::new(
        nats_client.clone(),
        self.component_registry.clone(),
        self.cancellation_token.clone(),
    );
    Ok(server as Arc<dyn RequestPlaneServer>)
}
// ============================================================================
// PRIVATE: Client Creation
// ============================================================================
/// Creates an HTTP client from a copy of the stored HTTP client configuration.
///
/// # Errors
///
/// Propagates any error from `HttpRequestClient::with_config`.
fn create_http_client(&self) -> Result<Arc<dyn RequestPlaneClient>> {
    use super::egress::http_router::HttpRequestClient;

    tracing::debug!("Creating HTTP request plane client with config from NetworkManager");

    let settings = self.config.http_client_config.clone();
    let client = HttpRequestClient::with_config(settings)?;
    Ok(Arc::new(client))
}
/// Creates a TCP client from a copy of the stored TCP client configuration.
///
/// # Errors
///
/// Propagates any error from `TcpRequestClient::with_config`.
fn create_tcp_client(&self) -> Result<Arc<dyn RequestPlaneClient>> {
    use super::egress::tcp_client::TcpRequestClient;

    tracing::debug!("Creating TCP request plane client with config from NetworkManager");

    let settings = self.config.tcp_client_config.clone();
    let client = TcpRequestClient::with_config(settings)?;
    Ok(Arc::new(client))
}
/// Creates a NATS request client that wraps a clone of the configured NATS client.
///
/// # Errors
///
/// Returns an error if no NATS client was supplied to the manager.
fn create_nats_client(&self) -> Result<Arc<dyn RequestPlaneClient>> {
    use super::egress::nats_client::NatsRequestClient;

    let Some(nats_client) = self.config.nats_client.as_ref() else {
        return Err(anyhow::anyhow!("NATS client required for NATS mode"));
    };

    tracing::debug!("Creating NATS request plane client");

    Ok(Arc::new(NatsRequestClient::new(nats_client.clone())))
}
}
......@@ -27,7 +27,7 @@ use tokio::{signal, sync::Mutex, task::JoinHandle};
pub use tokio_util::sync::CancellationToken;
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
#[derive(Clone)]
#[derive(Clone, Debug)]
enum RuntimeType {
Shared(Arc<tokio::runtime::Runtime>),
External(tokio::runtime::Handle),
......@@ -339,12 +339,3 @@ impl RuntimeType {
}
}
}
impl std::fmt::Debug for RuntimeType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RuntimeType::External(_) => write!(f, "RuntimeType::External"),
RuntimeType::Shared(_) => write!(f, "RuntimeType::Shared"),
}
}
}
......@@ -798,9 +798,7 @@ mod integration_tests {
endpoint: "health".to_string(),
namespace: "test_namespace".to_string(),
instance_id: 1,
transport: crate::component::TransportType::NatsTcp(
endpoint.to_string(),
),
transport: crate::component::TransportType::Nats(endpoint.to_string()),
},
health_check_payload.clone(),
);
......
......@@ -400,12 +400,6 @@ impl Default for NatsAuth {
}
}
/// Is this file name / url in the NATS object store?
/// Checks the name only, does not go to the store.
pub fn is_nats_url(s: &str) -> bool {
s.starts_with(URL_PREFIX)
}
/// Extract NATS bucket and key from a nats URL of the form:
/// nats://host[:port]/bucket/key
pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment