Unverified Commit 63d7c01c authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: velo-backend (#6547)


Signed-off-by: Ryan Olson <rolson@nvidia.com>
parent 2cc92bfa
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP listener for ActiveMessage transport
//!
//! This module provides a TCP server that accepts incoming connections,
//! decodes framed messages using zero-copy techniques, and routes them
//! to the appropriate transport streams.
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::StreamExt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener as TokioTcpListener, TcpStream};
use tokio::runtime::{Handle, Runtime};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
use super::framing::TcpFrameCodec;
/// Runtime configuration for the TCP listener.
///
/// Selects which tokio runtime `TcpListener::serve` uses to run the accept loop.
pub enum RuntimeConfig {
/// Use an existing tokio runtime handle; the accept loop is spawned onto it.
Handle(Handle),
/// Use a provided tokio runtime; the accept loop is spawned onto it.
Runtime(Arc<Runtime>),
/// Create a single-threaded runtime pinned to the specified CPU core (Linux only).
/// On other platforms a single-threaded runtime is still created, without pinning.
CpuPin(usize),
}
/// High-performance TCP listener for ActiveMessage transport
///
/// This listener accepts incoming TCP connections and routes decoded frames
/// to the appropriate transport streams with zero-copy performance.
pub struct TcpListener {
// Address to bind when no pre-bound `listener` is supplied.
bind_addr: SocketAddr,
// Channels that decoded frames are routed to, selected by message type.
adapter: TransportAdapter,
// Invoked with the original header/payload when a frame cannot be delivered.
error_handler: Arc<dyn TransportErrorHandler>,
// Drain/teardown coordination shared with the accept loop and every connection.
shutdown_state: ShutdownState,
// Where `serve` runs the accept loop.
runtime_config: RuntimeConfig,
// Optional pre-bound socket; when present it is used instead of binding `bind_addr`.
listener: Option<std::net::TcpListener>,
}
impl TcpListener {
/// Create a new builder for TcpListener.
///
/// The returned builder starts with every field unset.
pub fn builder() -> TcpListenerBuilder {
TcpListenerBuilder::new()
}
/// Start the listener and serve incoming connections
///
/// This method blocks or spawns based on the runtime configuration:
/// - For Handle/Runtime: spawns tasks and returns immediately
/// - For CpuPin: creates a pinned runtime and blocks until cancellation
pub async fn serve(mut self) -> Result<()> {
// Extract runtime config to avoid borrow issues
let runtime_config = std::mem::replace(
&mut self.runtime_config,
RuntimeConfig::Handle(Handle::current()),
);
match runtime_config {
RuntimeConfig::Handle(handle) => {
handle.spawn(async move {
if let Err(e) = self.run_server().await {
error!("TCP listener error: {}", e);
}
});
Ok(())
}
RuntimeConfig::Runtime(rt) => {
rt.spawn(async move {
if let Err(e) = self.run_server().await {
error!("TCP listener error: {}", e);
}
});
Ok(())
}
RuntimeConfig::CpuPin(cpu_id) => {
let rt = Self::create_pinned_runtime(cpu_id)
.context("Failed to create CPU-pinned runtime")?;
rt.block_on(self.run_server())
}
}
}
/// Create a single-threaded runtime pinned to a specific CPU core
#[cfg(target_os = "linux")]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
use nix::sched::{CpuSet, sched_setaffinity};
use nix::unistd::Pid;
tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name("tcp-listener-pinned")
.on_thread_start(move || {
let mut cpu_set = CpuSet::new();
if cpu_set.set(cpu_id).is_ok() {
if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set) {
error!("Failed to pin thread to CPU {}: {}", cpu_id, e);
} else {
debug!("Successfully pinned TCP listener to CPU {}", cpu_id);
}
}
})
.build()
.context("Failed to build tokio runtime")
}
/// Fallback for platforms other than Linux: builds the same single-threaded
/// runtime but only warns that the requested CPU pinning is unavailable.
///
/// # Errors
///
/// Returns an error if the tokio runtime cannot be built.
#[cfg(not(target_os = "linux"))]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
    warn!(
        "CPU pinning requested (CPU {}) but not supported on this platform",
        cpu_id
    );
    let mut runtime_builder = tokio::runtime::Builder::new_current_thread();
    runtime_builder.enable_all().thread_name("tcp-listener");
    runtime_builder
        .build()
        .context("Failed to build tokio runtime")
}
/// Main server loop that accepts connections.
///
/// Uses the pre-bound listener when one was supplied, otherwise binds
/// `bind_addr`. Each accepted connection is handled on its own spawned task.
/// The loop ends when the teardown token is cancelled; accept errors are
/// logged and the loop keeps running.
///
/// # Errors
///
/// Returns an error if the listening socket cannot be prepared (non-blocking
/// mode, conversion to a tokio listener, bind, or local address lookup).
async fn run_server(self) -> Result<()> {
    // Copied out up front so the error-context closure below does not need to
    // borrow `self` after `self.listener` has been moved.
    let bind_addr = self.bind_addr;
    let listener = match self.listener {
        Some(std_listener) => {
            // tokio requires the std socket to be non-blocking before conversion.
            std_listener
                .set_nonblocking(true)
                .context("Failed to set listener to non-blocking")?;
            TokioTcpListener::from_std(std_listener)
                .context("Failed to convert std TcpListener to tokio TcpListener")?
        }
        None => TokioTcpListener::bind(bind_addr)
            .await
            // Lazy: the message is only formatted when the bind actually fails.
            .with_context(|| format!("Failed to bind TCP listener to {}", bind_addr))?,
    };
    let local_addr = listener
        .local_addr()
        .context("Failed to get local address")?;
    info!("TCP listener bound to {}", local_addr);
    let teardown_token = self.shutdown_state.teardown_token().clone();
    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                match accept_result {
                    Ok((stream, peer_addr)) => {
                        debug!("Accepted TCP connection from {}", peer_addr);
                        let adapter = self.adapter.clone();
                        let error_handler = self.error_handler.clone();
                        let shutdown_state = self.shutdown_state.clone();
                        // One task per connection so a slow peer cannot stall accepts.
                        tokio::spawn(async move {
                            if let Err(e) = Self::handle_connection(
                                stream,
                                peer_addr,
                                adapter,
                                error_handler,
                                shutdown_state,
                            )
                            .await
                            {
                                warn!("Error handling connection from {}: {}", peer_addr, e);
                            }
                        });
                    }
                    Err(e) => {
                        error!("Failed to accept TCP connection: {}", e);
                    }
                }
            }
            _ = teardown_token.cancelled() => {
                info!("TCP listener shutting down (teardown)");
                break;
            }
        }
    }
    Ok(())
}
/// Handle a single TCP connection.
///
/// Applies socket options (each failure is only logged), then reads frames
/// until the peer closes the connection, a frame fails to decode, or the
/// teardown token is cancelled. While draining, `Message` frames are answered
/// with a `ShuttingDown` frame instead of being routed.
///
/// As written this always returns `Ok(())`; routing and write failures are
/// logged rather than propagated.
async fn handle_connection(
stream: TcpStream,
peer_addr: SocketAddr,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
) -> Result<()> {
debug!("Configuring connection from {}", peer_addr);
// Configure socket for high performance
if let Err(e) = stream.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY on {}: {}", peer_addr, e);
}
#[allow(deprecated)] // Intentional: linger ensures clean socket shutdown
if let Err(e) = stream.set_linger(Some(Duration::from_secs(1))) {
warn!("Failed to set linger on {}: {}", peer_addr, e);
}
// Set keep-alive to detect dead connections
let keepalive = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10));
// Borrowed view of the same socket, used for options tokio does not expose.
let sock_ref = socket2::SockRef::from(&stream);
if let Err(e) = sock_ref.set_tcp_keepalive(&keepalive) {
warn!("Failed to set TCP keepalive on {}: {}", peer_addr, e);
}
// Set large receive buffer for high throughput
if let Err(e) = sock_ref.set_recv_buffer_size(1_048_576) {
warn!("Failed to set receive buffer size on {}: {}", peer_addr, e);
}
// Create framed stream with zero-copy codec
let mut framed = Framed::new(stream, TcpFrameCodec::new());
let teardown_token = shutdown_state.teardown_token().clone();
debug!("Connection from {} ready for frames", peer_addr);
loop {
tokio::select! {
frame_result = framed.next() => {
match frame_result {
Some(Ok((msg_type, header, payload))) => {
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if shutdown_state.is_draining() && msg_type == MessageType::Message {
debug!(
"Rejecting Message frame from {} during drain (sending ShuttingDown)",
peer_addr
);
// Echo original header back for correlation, empty payload
// The reply is written straight to the underlying stream via `get_mut`.
if let Err(e) = TcpFrameCodec::encode_frame(
framed.get_mut(),
MessageType::ShuttingDown,
&header,
&[],
)
.await
{
warn!(
"Failed to send ShuttingDown frame to {}: {}",
peer_addr, e
);
}
// Keep reading: later frames on this connection are still processed.
continue;
}
// Route frame to appropriate stream based on type
if let Err(e) = Self::route_frame(
msg_type,
header,
payload,
&adapter,
&error_handler,
)
.await
{
// A routing failure does not close the connection.
warn!(
"Failed to route {:?} frame from {}: {}",
msg_type, peer_addr, e
);
}
}
Some(Err(e)) => {
// A decode error ends the connection.
error!("Frame decode error from {}: {}", peer_addr, e);
break;
}
None => {
// Connection closed gracefully (FIN received)
debug!("Connection from {} closed gracefully", peer_addr);
break;
}
}
}
_ = teardown_token.cancelled() => {
debug!("Connection handler for {} torn down", peer_addr);
break;
}
}
}
Ok(())
}
/// Deliver one decoded frame to the adapter stream that matches its type.
///
/// Ownership of `header` and `payload` moves into the channel, so nothing is
/// copied on the success path. If the channel rejects the send, the returned
/// buffers are handed to `error_handler` and an error is returned.
///
/// # Errors
///
/// Returns an error when the selected stream does not accept the frame.
async fn route_frame(
    msg_type: MessageType,
    header: Bytes,
    payload: Bytes,
    adapter: &TransportAdapter,
    error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
    let target = match msg_type {
        MessageType::Message => &adapter.message_stream,
        // ShuttingDown is an outbound-only frame type; seeing it inbound means a
        // remote peer rejected our request, so it travels with responses and
        // higher layers can correlate the rejection.
        MessageType::Response | MessageType::ShuttingDown => &adapter.response_stream,
        MessageType::Ack | MessageType::Event => &adapter.event_stream,
    };
    if let Err(send_err) = target.send_async((header, payload)).await {
        // The send error gives the undelivered (header, payload) pair back.
        let (header, payload) = send_err.0;
        error_handler.on_error(header, payload, format!("Failed to route {:?}", msg_type));
        return Err(anyhow::anyhow!("Failed to send to stream"));
    }
    Ok(())
}
}
/// Builder for TcpListener
///
/// `bind_addr`, `adapter` and `error_handler` are required by `build()`; the
/// remaining fields fall back to defaults when left unset.
pub struct TcpListenerBuilder {
// Required: address to bind (still required when a pre-bound listener is given).
bind_addr: Option<SocketAddr>,
// Required: destination channels for decoded frames.
adapter: Option<TransportAdapter>,
// Required: callback for frames that could not be delivered.
error_handler: Option<Arc<dyn TransportErrorHandler>>,
// Optional: defaults to `ShutdownState::default()`.
shutdown_state: Option<ShutdownState>,
// Optional: defaults to the ambient tokio runtime handle.
runtime_config: Option<RuntimeConfig>,
// Optional: pre-bound socket to serve on.
listener: Option<std::net::TcpListener>,
}
impl TcpListenerBuilder {
    /// Create a new builder with every field unset.
    pub fn new() -> Self {
        Self {
            bind_addr: None,
            adapter: None,
            error_handler: None,
            shutdown_state: None,
            runtime_config: None,
            listener: None,
        }
    }

    /// Set the bind address (required).
    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
        self.bind_addr = Some(addr);
        self
    }

    /// Set the transport adapter that decoded frames are routed to (required).
    pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
        self.adapter = Some(adapter);
        self
    }

    /// Set the error handler invoked for undeliverable frames (required).
    pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
        self.error_handler = Some(handler);
        self
    }

    /// Set the shutdown state for graceful drain coordination.
    pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
        self.shutdown_state = Some(state);
        self
    }

    /// Use an existing tokio runtime handle.
    pub fn with_handle(mut self, handle: Handle) -> Self {
        self.runtime_config = Some(RuntimeConfig::Handle(handle));
        self
    }

    /// Use a provided tokio runtime.
    pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
        self.runtime_config = Some(RuntimeConfig::Runtime(runtime));
        self
    }

    /// Create a single-threaded runtime pinned to a specific CPU core.
    pub fn with_cpu_pin(mut self, cpu_id: usize) -> Self {
        self.runtime_config = Some(RuntimeConfig::CpuPin(cpu_id));
        self
    }

    /// Use a pre-bound TcpListener.
    ///
    /// This is useful for tests where you want to bind to port 0 and avoid port races.
    /// When provided, the bind_addr should still be set (for logging/debugging purposes).
    pub fn listener(mut self, listener: Option<std::net::TcpListener>) -> Self {
        self.listener = listener;
        self
    }

    /// Build the TcpListener.
    ///
    /// # Errors
    ///
    /// Returns an error if `bind_addr`, `adapter` or `error_handler` is missing,
    /// or if no runtime was configured and there is no ambient tokio runtime to
    /// fall back to.
    pub fn build(self) -> Result<TcpListener> {
        let bind_addr = self
            .bind_addr
            .ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;
        let adapter = self
            .adapter
            .ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
        let error_handler = self
            .error_handler
            .ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
        let shutdown_state = self.shutdown_state.unwrap_or_default();
        let runtime_config = match self.runtime_config {
            Some(config) => config,
            // `Handle::current()` would panic outside a runtime; since this
            // function already returns `Result`, report the problem instead.
            None => RuntimeConfig::Handle(Handle::try_current().context(
                "runtime_config was not set and no ambient tokio runtime is available",
            )?),
        };
        Ok(TcpListener {
            bind_addr,
            adapter,
            error_handler,
            shutdown_state,
            runtime_config,
            listener: self.listener,
        })
    }
}
/// `Default` yields the same empty builder as [`TcpListenerBuilder::new`].
impl Default for TcpListenerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::make_channels;
    use std::net::Ipv4Addr;

    /// Error handler that only prints the error text to stderr.
    struct TestErrorHandler;

    impl TransportErrorHandler for TestErrorHandler {
        fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
            eprintln!("Test error handler: {}", error);
        }
    }

    /// Loopback address with an OS-assigned port.
    fn loopback_any_port() -> SocketAddr {
        SocketAddr::from((Ipv4Addr::LOCALHOST, 0))
    }

    /// An empty builder must be rejected.
    #[test]
    fn test_builder_requires_fields() {
        assert!(TcpListener::builder().build().is_err());
    }

    /// The three required fields are sufficient inside a tokio runtime.
    #[tokio::test]
    async fn test_builder_with_all_fields() {
        let (adapter, _streams) = make_channels();
        let built = TcpListener::builder()
            .bind_addr(loopback_any_port())
            .adapter(adapter)
            .error_handler(Arc::new(TestErrorHandler))
            .build();
        assert!(built.is_ok());
    }

    /// An explicit CPU-pin config lets `build` succeed without a runtime.
    #[test]
    fn test_builder_with_cpu_pin() {
        let (adapter, _streams) = make_channels();
        let built = TcpListener::builder()
            .bind_addr(loopback_any_port())
            .adapter(adapter)
            .error_handler(Arc::new(TestErrorHandler))
            .with_cpu_pin(0)
            .build();
        assert!(built.is_ok());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TCP Transport Module
//!
//! This module provides a high-performance TCP transport implementation with:
//! - Zero-copy frame codec for minimal overhead
//! - CPU pinning support for predictable latency
//! - Frame type routing (Message, Response, Ack, Event)
//! - Graceful shutdown with proper FIN handling
//! - Keep-alive for dead connection detection
mod framing;
mod listener;
mod transport;
pub use framing::TcpFrameCodec;
pub use listener::{RuntimeConfig, TcpListener, TcpListenerBuilder};
pub use transport::{TcpTransport, TcpTransportBuilder};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP transport
//!
//! Connection and peer state is shared through `Arc` + `DashMap`, and each outbound
//! connection is served by a writer task spawned onto the tokio runtime handle that
//! is passed to `start()`.
use anyhow::{Context, Result};
use bytes::Bytes;
use dashmap::DashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use crate::transport::{HealthCheckError, ShutdownState, TransportError, TransportErrorHandler};
use crate::{MessageType, PeerInfo, Transport, TransportAdapter, TransportKey, WorkerAddress};
use super::framing::TcpFrameCodec;
use super::listener::TcpListener;
/// High-performance TCP transport with concurrent access to connection state
///
/// Peer addresses and live connections are kept in `DashMap`s so they can be
/// read and updated through `&self` from any thread. Per-connection writer
/// tasks are spawned on the runtime handle stored by `start()`.
pub struct TcpTransport {
// Identity (immutable, no wrapper needed)
key: TransportKey,
bind_addr: SocketAddr,
// Address advertised to peers (built from the bind address by the builder).
local_address: WorkerAddress,
// Shared mutable state with DashMap (concurrent map, usable through `&self`)
// `peers`: registered peer -> socket address; `connections`: peer -> writer channel.
peers: Arc<DashMap<crate::InstanceId, SocketAddr>>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
// Runtime handle for spawning tasks
// Set once by `start()`; sends fail with "not started" until then.
runtime: OnceLock<tokio::runtime::Handle>,
// Shutdown coordination
cancel_token: CancellationToken,
shutdown_state: OnceLock<ShutdownState>,
// Send channel capacity for backpressure
channel_capacity: usize,
// Optional pre-bound listener (used for tests to avoid port races)
// Taken (left as `None`) by the first call to `start()`.
listener: Mutex<Option<std::net::TcpListener>>,
}
/// Handle to a connection's writer task
///
/// Cloning is cheap: it only clones the channel sender. The handle counts as
/// stale once `tx.is_disconnected()` reports that the receiver is gone.
#[derive(Clone)]
struct ConnectionHandle {
// Sending half of the bounded queue drained by the writer task.
tx: flume::Sender<SendTask>,
}
/// Task sent to the writer task: one frame to encode and write.
///
/// Carries the frame parts (type, header, payload) plus the handler that is
/// notified if the frame cannot be delivered.
struct SendTask {
msg_type: MessageType,
header: Bytes,
payload: Bytes,
// Called with `header` and `payload` when delivery fails.
on_error: Arc<dyn TransportErrorHandler>,
}
impl SendTask {
    /// Consume the task and report `error` to its handler, giving the
    /// undelivered header and payload back to the caller's callback.
    fn on_error(self, error: impl Into<String>) {
        let Self {
            header,
            payload,
            on_error: handler,
            ..
        } = self;
        handler.on_error(header, payload, error.into());
    }
}
impl TcpTransport {
/// Construct a transport for `bind_addr`, identified by `key`.
///
/// `local_address` is the address reported by `address()`. `channel_capacity`
/// bounds each per-connection writer queue (the builder defaults it to 256).
/// `listener` may carry an already-bound socket, which is handy for tests
/// that bind port 0. Nothing is bound or spawned until `start()` is called.
pub fn new(
    bind_addr: SocketAddr,
    key: TransportKey,
    local_address: WorkerAddress,
    channel_capacity: usize,
    listener: Option<std::net::TcpListener>,
) -> Self {
    let peers = Arc::new(DashMap::new());
    let connections = Arc::new(DashMap::new());
    let listener = Mutex::new(listener);
    Self {
        key,
        bind_addr,
        local_address,
        peers,
        connections,
        runtime: OnceLock::new(),
        cancel_token: CancellationToken::new(),
        shutdown_state: OnceLock::new(),
        channel_capacity,
        listener,
    }
}
/// Optionally set up the connection to a registered peer ahead of time.
///
/// Call this after `register()` to create the connection handle (and spawn its
/// writer task) now instead of on the first `send_message()`. The TCP connect
/// itself happens inside the writer task, so a successful return does not mean
/// the socket is already established.
///
/// # Errors
///
/// Fails if the transport has not been started or the peer is not registered.
pub fn ensure_connected(&self, instance_id: crate::InstanceId) -> Result<()> {
    self.get_or_create_connection(instance_id)
        .map(|_handle| ())
}
/// Return a live connection handle for `instance_id`, creating one lazily.
///
/// # Errors
///
/// Fails with `NotStarted` when no runtime handle has been stored yet, or with
/// whatever `create_connection` reports (e.g. an unregistered peer).
fn get_or_create_connection(&self, instance_id: crate::InstanceId) -> Result<ConnectionHandle> {
    use dashmap::mapref::entry::Entry;

    // Fast path: a read-only lookup is enough when the connection is alive.
    if let Some(existing) = self.connections.get(&instance_id) {
        if !existing.tx.is_disconnected() {
            return Ok(existing.clone());
        }
        // Stale: release the read guard before touching the map again.
        drop(existing);
        self.connections
            .remove_if(&instance_id, |_, h| h.tx.is_disconnected());
    }

    let rt = self.runtime.get().ok_or(TransportError::NotStarted)?;

    // Slow path: check-and-insert through the entry API.
    match self.connections.entry(instance_id) {
        // Someone else created a live connection in the meantime.
        Entry::Occupied(slot) if !slot.get().tx.is_disconnected() => Ok(slot.get().clone()),
        // Stale entry: replace it in place with a fresh connection.
        Entry::Occupied(mut slot) => {
            let fresh = self.create_connection(instance_id, rt)?;
            slot.insert(fresh.clone());
            Ok(fresh)
        }
        Entry::Vacant(slot) => {
            let fresh = self.create_connection(instance_id, rt)?;
            slot.insert(fresh.clone());
            Ok(fresh)
        }
    }
}
/// Build a fresh connection handle for `instance_id` and spawn its writer task
/// on `rt`.
///
/// The caller is responsible for storing the returned handle in the map.
///
/// # Errors
///
/// Fails with `PeerNotRegistered` when no address is known for the peer.
fn create_connection(
    &self,
    instance_id: crate::InstanceId,
    rt: &tokio::runtime::Handle,
) -> Result<ConnectionHandle> {
    let peer_addr = match self.peers.get(&instance_id) {
        Some(entry) => *entry.value(),
        None => return Err(TransportError::PeerNotRegistered(instance_id).into()),
    };
    // Bounded queue: senders see backpressure once the writer falls behind.
    let (tx, rx) = flume::bounded(self.channel_capacity);
    rt.spawn(connection_writer_task(
        peer_addr,
        instance_id,
        rx,
        Arc::clone(&self.connections),
        self.cancel_token.clone(),
    ));
    debug!("Created new connection to {} ({})", instance_id, peer_addr);
    Ok(ConnectionHandle { tx })
}
}
impl Transport for TcpTransport {
/// Return a copy of the key identifying this transport.
fn key(&self) -> TransportKey {
self.key.clone()
}
/// Return a copy of the worker address this transport was constructed with.
fn address(&self) -> WorkerAddress {
self.local_address.clone()
}
/// Record the TCP address of a peer so later sends can connect to it.
///
/// The peer's endpoint is looked up under this transport's key and may be
/// written as `tcp://host:port` or `host:port`. No connection is opened here.
///
/// # Errors
///
/// `NoEndpoint` when the peer has no entry for this transport, and
/// `InvalidEndpoint` when the entry cannot be turned into a socket address.
fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError> {
    let endpoint = match peer_info.worker_address().get_entry(&self.key) {
        Ok(Some(endpoint)) => endpoint,
        // A lookup failure and a missing entry are reported the same way.
        Ok(None) | Err(_) => return Err(TransportError::NoEndpoint),
    };
    let addr = match parse_tcp_endpoint(&endpoint) {
        Ok(addr) => addr,
        Err(e) => {
            error!("Failed to parse TCP endpoint: {}", e);
            return Err(TransportError::InvalidEndpoint);
        }
    };
    let peer_id = peer_info.instance_id();
    self.peers.insert(peer_id, addr);
    debug!("Registered peer {} at {}", peer_id, addr);
    Ok(())
}
/// Queue one frame for delivery to `instance_id` without blocking the caller.
///
/// Fast path: `try_send` on an existing connection's queue. Otherwise (no
/// connection, full queue, or dead connection) a connection is obtained or
/// created and the send is completed from a spawned task. Every failure is
/// reported through `on_error`; nothing is returned to the caller.
///
/// NOTE(review): frames that take the slow path are sent from independently
/// spawned tasks, so their order relative to each other does not appear to be
/// guaranteed — confirm this is acceptable for callers.
#[inline]
fn send_message(
&self,
instance_id: crate::InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: std::sync::Arc<dyn TransportErrorHandler>,
) {
// Convert to Bytes (one allocation each)
let header = Bytes::from(header);
let payload = Bytes::from(payload);
let send_msg = SendTask {
msg_type: message_type,
header,
payload,
on_error,
};
// Fast path: try to send on existing connection
// Each failing branch hands `send_msg` back so the slow path can reuse it.
let send_msg = match self.connections.get(&instance_id) {
Some(handle) => match handle.tx.try_send(send_msg) {
Ok(()) => return,
// Queue full: fall through and wait for capacity on a spawned task.
Err(flume::TrySendError::Full(send_msg)) => send_msg,
Err(flume::TrySendError::Disconnected(send_msg)) => {
// Drop the guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
// Fall through to slow path to create a fresh connection
send_msg
}
},
None => send_msg,
};
// Slow path: create new connection
let rt = match self.runtime.get() {
Some(rt) => rt,
None => {
send_msg.on_error("Transport not started");
return;
}
};
let handle = match self.get_or_create_connection(instance_id) {
Ok(h) => h,
Err(e) => {
send_msg.on_error(format!("Failed to create connection: {}", e));
return;
}
};
// Await queue capacity off the caller's thread.
rt.spawn(async move {
if let Err(flume::SendError(send_msg)) = handle.tx.send_async(send_msg).await {
send_msg.on_error("Connection closed");
}
});
}
/// Start the transport: remember the runtime handle and launch the TCP listener.
///
/// The first call stores `rt` and the adapter's shutdown state and takes the
/// pre-bound listener (if any); repeated calls leave the stored values untouched.
/// The listener is built for, and spawned on, `rt`, so the accept loop and its
/// connection handlers run on the runtime the caller supplied rather than on
/// whichever runtime happens to poll the returned future.
///
/// # Errors
///
/// The returned future fails if the listener cannot be built. Bind and accept
/// errors happen later on the spawned task and are only logged.
///
/// # Panics
///
/// Panics if the listener mutex is poisoned.
fn start(
    &self,
    _instance_id: crate::InstanceId,
    channels: TransportAdapter,
    rt: tokio::runtime::Handle,
) -> futures::future::BoxFuture<'_, anyhow::Result<()>> {
    // Store runtime handle for use in send_message
    self.runtime.set(rt.clone()).ok();
    // Capture shutdown state from the adapter
    self.shutdown_state
        .set(channels.shutdown_state.clone())
        .ok();
    let bind_addr = self.bind_addr;
    let shutdown_state = channels.shutdown_state.clone();
    // Take ownership of the listener (if present) - we can only start once
    let listener = self
        .listener
        .lock()
        .expect("Listener mutex poisoned")
        .take();
    Box::pin(async move {
        // Listener-side error handler: delivery failures are only logged.
        struct DefaultErrorHandler;
        impl TransportErrorHandler for DefaultErrorHandler {
            fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
                warn!("Transport error: {}", error);
            }
        }
        // Pass `rt` explicitly; without it the builder falls back to the ambient
        // runtime of whoever polls this future, ignoring the supplied handle.
        let tcp_listener = TcpListener::builder()
            .bind_addr(bind_addr)
            .adapter(channels)
            .error_handler(std::sync::Arc::new(DefaultErrorHandler))
            .shutdown_state(shutdown_state)
            .with_handle(rt.clone())
            .listener(listener)
            .build()?;
        rt.spawn(async move {
            if let Err(e) = tcp_listener.serve().await {
                error!("TCP listener error: {}", e);
            }
        });
        info!("TCP transport started on {}", bind_addr);
        Ok(())
    })
}
/// Begin the drain phase. Intentionally does nothing in this transport.
fn begin_drain(&self) {
// Per-frame gate in the listener handles drain — no-op here.
}
/// Tear the transport down: stop the listener side and drop all outbound
/// connection handles.
fn shutdown(&self) {
info!("Shutting down TCP transport");
// Cancel the teardown token (Phase 3) to stop the listener and connection handlers
// The state is only present after `start()`; before that there is nothing to stop.
if let Some(state) = self.shutdown_state.get() {
state.teardown_token().cancel();
}
// NOTE(review): the writer task receives a clone of this token as
// `_cancel_token` and does not appear to observe it — confirm whether
// cancellation is meant to interrupt writers.
self.cancel_token.cancel();
// Clear connections
// Removes the map's sender handles for every connection.
self.connections.clear();
}
/// Check whether the peer `instance_id` is reachable.
///
/// - A live connection (writer channel still connected) is healthy: `Ok(())`.
/// - Otherwise a throw-away TCP connect to the registered address is attempted,
///   bounded by `timeout`. If it succeeds the result is `Ok(())` when a (now
///   dead) connection entry existed, and `NeverConnected` when there was none.
///
/// # Errors
///
/// `PeerNotRegistered`, `NeverConnected`, `ConnectionFailed` or `Timeout`.
fn check_health(
    &self,
    instance_id: crate::InstanceId,
    timeout: Duration,
) -> std::pin::Pin<
    Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
> {
    Box::pin(async move {
        // A single lookup decides both "is it alive?" and "did an entry exist?".
        // The earlier separate `contains_key` + `get` pair could disagree when
        // the entry changed between the two calls.
        let connection_exists = if let Some(handle) = self.connections.get(&instance_id) {
            // If the writer task has exited (socket closed), the channel is disconnected.
            if !handle.tx.is_disconnected() {
                return Ok(()); // Connection is alive and healthy
            }
            // Channel is disconnected — drop guard and remove stale entry
            drop(handle);
            self.connections
                .remove_if(&instance_id, |_, h| h.tx.is_disconnected());
            true
        } else {
            false
        };
        // No existing connection or connection is dead - verify peer is reachable
        let addr = *self
            .peers
            .get(&instance_id)
            .ok_or(HealthCheckError::PeerNotRegistered)?
            .value();
        // Try to connect (and immediately drop) to verify peer is reachable
        match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
            Ok(Ok(_stream)) => {
                // If we never had a connection before, report NeverConnected.
                // If we had one before that failed, report Ok (peer is reachable now).
                if connection_exists {
                    Ok(())
                } else {
                    Err(HealthCheckError::NeverConnected)
                }
            }
            Ok(Err(_)) => Err(HealthCheckError::ConnectionFailed),
            Err(_) => Err(HealthCheckError::Timeout),
        }
    })
}
}
/// Connection writer task
///
/// Runs as a spawned tokio task. It connects to `addr`, then receives send
/// tasks over a flume channel, encodes each one as a frame and writes it to
/// the socket.
///
/// Cleanup (draining queued messages and removing the stale map entry) always runs,
/// even if the initial TCP connect fails.
///
/// Returns the result of the inner connect/write loop.
///
/// NOTE(review): `_cancel_token` is accepted but never observed here — confirm
/// whether cancellation is supposed to interrupt the connect or the write loop.
async fn connection_writer_task(
addr: SocketAddr,
instance_id: crate::InstanceId,
rx: flume::Receiver<SendTask>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
_cancel_token: CancellationToken,
) -> Result<()> {
let result = connection_writer_inner(addr, instance_id, &rx).await;
// Always drain queued messages and notify their error handlers.
//
// TODO: There is a tiny race between the drain finishing and `drop(rx)`:
// a sender on another thread could `try_send` successfully in that window,
// and the message would be silently dropped when rx is destroyed. Closing
// this fully would require swapping the map entry with a "poisoned" handle
// (a disconnected tx) before draining, so fast-path senders see a failure
// instead. Not worth the complexity today — at most one message is affected,
// and async senders already get `SendError` once rx is dropped.
while let Ok(msg) = rx.try_recv() {
msg.on_error("Connection closed");
}
// Drop the receiver so our sender half becomes disconnected, then remove
// the stale entry. The predicate ensures we only remove our own entry —
// a replacement connection's tx will still be connected.
drop(rx);
connections.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
debug!("Connection to {} ({}) closed", instance_id, addr);
result
}
/// Connect to `addr`, tune the socket, then write every queued frame until the
/// channel has no more senders or a write fails.
///
/// Socket tuning is best-effort: each failing option is logged and skipped. A
/// write failure is reported to that message's error handler and ends the loop.
///
/// # Errors
///
/// Only a failed TCP connect is returned as an error.
async fn connection_writer_inner(
    addr: SocketAddr,
    instance_id: crate::InstanceId,
    rx: &flume::Receiver<SendTask>,
) -> Result<()> {
    debug!("Connecting to {}", addr);
    let mut stream = TcpStream::connect(addr).await.context("connect failed")?;

    if let Err(e) = stream.set_nodelay(true) {
        warn!("Failed to set TCP_NODELAY: {}", e);
    }
    {
        // Borrowed socket view for the options tokio does not expose; scoped so
        // the borrow ends before the stream is written to.
        let sock = socket2::SockRef::from(&stream);
        let keepalive = socket2::TcpKeepalive::new()
            .with_time(Duration::from_secs(60))
            .with_interval(Duration::from_secs(10));
        if let Err(e) = sock.set_tcp_keepalive(&keepalive) {
            warn!("Failed to set keepalive: {}", e);
        }
        if let Err(e) = sock.set_send_buffer_size(1_048_576) {
            warn!("Failed to set send buffer size: {}", e);
        }
    }
    debug!("Connected to {}", addr);

    loop {
        // `Err` means every sender is gone and the queue is empty.
        let msg = match rx.recv_async().await {
            Ok(msg) => msg,
            Err(_) => break,
        };
        let written =
            TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload)
                .await;
        if let Err(e) = written {
            error!("Write error to {} ({}): {}", instance_id, addr, e);
            msg.on_error(format!("Failed to write to stream: {}", e));
            break;
        }
    }
    Ok(())
}
/// Turn a TCP endpoint given as raw bytes into a `SocketAddr`.
///
/// Accepted forms are `tcp://host:port` and plain `host:port`. The address is
/// obtained through `ToSocketAddrs`, and the first resolved address is used.
///
/// # Errors
///
/// Fails when the bytes are not UTF-8, the text is not a usable socket
/// address, or resolution yields no address.
fn parse_tcp_endpoint(endpoint: &[u8]) -> Result<SocketAddr> {
    let text = std::str::from_utf8(endpoint).context("endpoint is not valid UTF-8")?;
    // The scheme prefix is optional.
    let host_port = match text.strip_prefix("tcp://") {
        Some(rest) => rest,
        None => text,
    };
    let first_resolved = host_port
        .to_socket_addrs()
        .context("failed to parse socket address")?
        .next();
    match first_resolved {
        Some(addr) => Ok(addr),
        None => Err(anyhow::anyhow!("no addresses resolved")),
    }
}
/// Pick the address to advertise to peers for a given bind address.
///
/// A wildcard bind (`0.0.0.0` or `::`) is not something a peer can connect to,
/// so it is replaced by the loopback address of the same family (`127.0.0.1`
/// or `::1`) with the port kept. That only works for same-machine
/// communication. Any specific address is returned unchanged.
///
/// In a production multi-node deployment, this should be replaced with actual
/// network interface discovery or explicit configuration.
fn resolve_advertise_address(bind_addr: SocketAddr) -> SocketAddr {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    // Already a specific address: advertise it as-is.
    if !bind_addr.ip().is_unspecified() {
        return bind_addr;
    }
    let loopback = match bind_addr {
        SocketAddr::V4(_) => IpAddr::V4(Ipv4Addr::LOCALHOST),
        SocketAddr::V6(_) => IpAddr::V6(Ipv6Addr::LOCALHOST),
    };
    SocketAddr::new(loopback, bind_addr.port())
}
/// Builder for TcpTransport
///
/// Either `bind_addr()` or `from_listener()` must supply the bind address
/// before `build()` is called.
pub struct TcpTransportBuilder {
// Address to bind; `from_listener` fills this from the listener's local address.
bind_addr: Option<SocketAddr>,
// Transport key; `build()` falls back to "tcp" when unset.
key: Option<TransportKey>,
// Capacity of each per-connection writer queue (256 unless overridden).
channel_capacity: usize,
// Optional pre-bound socket handed to the transport.
listener: Option<std::net::TcpListener>,
}
impl TcpTransportBuilder {
    /// Create a new builder with no address, no key, no listener and a channel
    /// capacity of 256.
    pub fn new() -> Self {
        Self {
            bind_addr: None,
            key: None,
            channel_capacity: 256,
            listener: None,
        }
    }

    /// Set the bind address.
    ///
    /// Mutually exclusive with `from_listener()`; a conflicting combination is
    /// rejected either by `from_listener()` or by `build()`.
    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
        self.bind_addr = Some(addr);
        self
    }

    /// Set the transport key (defaults to `"tcp"`).
    pub fn key(mut self, key: TransportKey) -> Self {
        self.key = Some(key);
        self
    }

    /// Set the channel capacity for backpressure (default: 256).
    pub fn channel_capacity(mut self, capacity: usize) -> Self {
        self.channel_capacity = capacity;
        self
    }

    /// Use a pre-bound TcpListener instead of binding to a specific address.
    ///
    /// This is useful for tests where you want to bind to port 0 and get an OS-assigned
    /// port without creating a race condition between binding and starting the transport.
    ///
    /// Note: This is mutually exclusive with `bind_addr()`. Using both will result in an error.
    ///
    /// # Errors
    ///
    /// Fails if `bind_addr()` was already called, or if the listener's local
    /// address cannot be read.
    pub fn from_listener(mut self, listener: std::net::TcpListener) -> Result<Self> {
        // Validate mutual exclusivity: can't use both bind_addr() and from_listener()
        if self.bind_addr.is_some() {
            anyhow::bail!(
                "Cannot use both bind_addr() and from_listener() - they are mutually exclusive"
            );
        }
        let addr = listener
            .local_addr()
            .context("Failed to get local address from listener")?;
        self.bind_addr = Some(addr);
        self.listener = Some(listener);
        Ok(self)
    }

    /// Build the TcpTransport.
    ///
    /// # Errors
    ///
    /// Fails if no bind address was provided, if `bind_addr()` was called after
    /// `from_listener()` with an address other than the one the listener is
    /// bound to, or if the advertised worker address cannot be built.
    pub fn build(self) -> Result<TcpTransport> {
        let bind_addr = self
            .bind_addr
            .ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;
        // `from_listener()` can only reject a *preceding* `bind_addr()` call.
        // Catch the reverse order here, otherwise the transport would advertise
        // an address its socket is not actually bound to.
        if let Some(listener) = &self.listener {
            let actual = listener
                .local_addr()
                .context("Failed to get local address from listener")?;
            if actual != bind_addr {
                anyhow::bail!(
                    "bind_addr {} does not match the pre-bound listener address {} - \
                     bind_addr() and from_listener() are mutually exclusive",
                    bind_addr,
                    actual
                );
            }
        }
        let key = self.key.unwrap_or_else(|| TransportKey::from("tcp"));
        // Resolve advertise address (handle 0.0.0.0 -> 127.0.0.1 for local testing)
        let advertise_addr = resolve_advertise_address(bind_addr);
        let local_endpoint = format!("tcp://{}", advertise_addr);
        let mut addr_builder = crate::address::WorkerAddressBuilder::new();
        addr_builder.add_entry(key.clone(), local_endpoint.as_bytes().to_vec())?;
        let local_address = addr_builder.build()?;
        Ok(TcpTransport::new(
            bind_addr,
            key,
            local_address,
            self.channel_capacity,
            self.listener,
        ))
    }
}
/// `Default` yields the same builder as [`TcpTransportBuilder::new`].
impl Default for TcpTransportBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the TCP transport: builder validation, advertise-address
    //! resolution, stale-connection recovery, and writer-task cleanup.

    use super::*;
    use crate::address::WorkerAddressBuilder;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use velo_common::PeerInfo;

    /// Error handler that discards errors (for tests that don't need to track them).
    struct NullErrorHandler;

    impl TransportErrorHandler for NullErrorHandler {
        fn on_error(&self, _: Bytes, _: Bytes, _: String) {}
    }

    /// Error handler that counts errors (for tests that verify error routing).
    struct TrackingErrorHandler {
        count: AtomicUsize,
    }

    impl TrackingErrorHandler {
        fn new() -> Self {
            Self {
                count: AtomicUsize::new(0),
            }
        }

        /// Number of times `on_error` has been invoked so far.
        fn error_count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl TransportErrorHandler for TrackingErrorHandler {
        fn on_error(&self, _: Bytes, _: Bytes, _: String) {
            self.count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Build a `PeerInfo` whose TCP endpoint points at `addr`.
    fn make_tcp_peer(addr: SocketAddr) -> PeerInfo {
        let instance_id = crate::InstanceId::new_v4();
        let mut builder = WorkerAddressBuilder::new();
        builder
            .add_entry("tcp", format!("tcp://{}", addr).into_bytes())
            .unwrap();
        PeerInfo::new(instance_id, builder.build().unwrap())
    }

    /// Build a `TcpTransport` with its runtime set, bound to a real listener.
    /// Returns `(transport, listener_addr)`.
    ///
    /// Must be called from inside a Tokio runtime because it captures
    /// `tokio::runtime::Handle::current()`.
    fn make_transport() -> (TcpTransport, SocketAddr) {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let transport = TcpTransportBuilder::new()
            .from_listener(listener)
            .unwrap()
            .build()
            .unwrap();
        // Set the runtime handle so `get_or_create_connection` can spawn tasks.
        transport
            .runtime
            .set(tokio::runtime::Handle::current())
            .ok();
        (transport, addr)
    }

    /// Insert a stale `ConnectionHandle` into the transport's connections map.
    /// A "stale" handle is one whose receiver has been dropped.
    fn insert_stale_handle(transport: &TcpTransport, instance_id: crate::InstanceId) {
        let (tx, rx) = flume::bounded::<SendTask>(1);
        // Drop the receiver *before* publishing the handle so that
        // `tx.is_disconnected()` is already true when the entry becomes visible.
        // (A `_rx` binding would keep the receiver alive until the function returns.)
        drop(rx);
        transport
            .connections
            .insert(instance_id, ConnectionHandle { tx });
    }

    #[test]
    fn test_parse_tcp_endpoint() {
        // With tcp:// prefix
        let addr = parse_tcp_endpoint(b"tcp://127.0.0.1:5555").unwrap();
        assert_eq!(addr.port(), 5555);
        // Without prefix
        let addr = parse_tcp_endpoint(b"127.0.0.1:6666").unwrap();
        assert_eq!(addr.port(), 6666);
        // Invalid
        assert!(parse_tcp_endpoint(b"invalid").is_err());
    }

    #[test]
    fn test_builder_requires_bind_addr() {
        let result = TcpTransportBuilder::new().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_with_bind_addr() {
        let addr = "127.0.0.1:0".parse().unwrap();
        let result = TcpTransportBuilder::new().bind_addr(addr).build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_with_listener() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let result = TcpTransportBuilder::new().from_listener(listener);
        assert!(result.is_ok());
        let result = result.unwrap().build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_bind_addr_and_listener_mutually_exclusive() {
        let addr = "127.0.0.1:0".parse().unwrap();
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let result = TcpTransportBuilder::new()
            .bind_addr(addr)
            .from_listener(listener);
        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("mutually exclusive"));
    }

    #[test]
    fn test_resolve_advertise_address_ipv4_unspecified() {
        use std::net::{IpAddr, Ipv4Addr};
        // 0.0.0.0 should resolve to 127.0.0.1
        let bind_addr: SocketAddr = "0.0.0.0:12345".parse().unwrap();
        let resolved = resolve_advertise_address(bind_addr);
        assert_eq!(resolved.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(resolved.port(), 12345);
        // Already specific address should remain unchanged
        let specific: SocketAddr = "192.168.1.100:8080".parse().unwrap();
        let resolved = resolve_advertise_address(specific);
        assert_eq!(resolved, specific);
    }

    #[test]
    fn test_resolve_advertise_address_ipv6_unspecified() {
        use std::net::{IpAddr, Ipv6Addr};
        // :: should resolve to ::1
        let bind_addr: SocketAddr = "[::]:12345".parse().unwrap();
        let resolved = resolve_advertise_address(bind_addr);
        assert_eq!(resolved.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
        assert_eq!(resolved.port(), 12345);
        // Already specific IPv6 address should remain unchanged
        let specific: SocketAddr = "[::1]:8080".parse().unwrap();
        let resolved = resolve_advertise_address(specific);
        assert_eq!(resolved, specific);
    }

    #[tokio::test]
    async fn test_get_or_create_connection_replaces_stale_handle() {
        let (transport, _our_addr) = make_transport();
        // Start a listener that the transport can connect to
        let peer_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();
        let peer = make_tcp_peer(peer_addr);
        let iid = peer.instance_id();
        transport.register(peer).unwrap();
        // Insert a stale handle
        insert_stale_handle(&transport, iid);
        assert!(
            transport
                .connections
                .get(&iid)
                .unwrap()
                .tx
                .is_disconnected()
        );
        // get_or_create_connection should replace the stale handle with a live one
        let handle = transport.get_or_create_connection(iid).unwrap();
        assert!(!handle.tx.is_disconnected());
        // The map entry should also be live
        let entry = transport.connections.get(&iid).unwrap();
        assert!(!entry.tx.is_disconnected());
    }

    #[tokio::test]
    async fn test_check_health_removes_stale_entry() {
        let (transport, _our_addr) = make_transport();
        // Start a listener so the peer is "reachable"
        let peer_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();
        let peer = make_tcp_peer(peer_addr);
        let iid = peer.instance_id();
        transport.register(peer).unwrap();
        // Insert stale handle — simulates a dead writer task
        insert_stale_handle(&transport, iid);
        assert!(transport.connections.contains_key(&iid));
        // check_health should remove the stale entry and verify the peer is reachable
        let result = transport.check_health(iid, Duration::from_secs(2)).await;
        // Stale entry should be gone
        assert!(!transport.connections.contains_key(&iid));
        // Since there WAS a previous connection entry, check_health returns Ok
        // (the peer is reachable via our test listener)
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_writer_task_cleans_up_on_write_error() {
        // Bind a listener, accept once, then drop everything to cause a write error
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let iid = crate::InstanceId::new_v4();
        let (tx, rx) = flume::bounded::<SendTask>(8);
        let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
            Arc::new(DashMap::new());
        connections.insert(iid, ConnectionHandle { tx: tx.clone() });
        let conns = Arc::clone(&connections);
        let cancel = CancellationToken::new();
        // Spawn the writer task
        let writer = tokio::spawn(connection_writer_task(addr, iid, rx, conns, cancel));
        // Accept the connection, then immediately drop it + the listener
        let (stream, _) = listener.accept().await.unwrap();
        drop(stream);
        drop(listener);
        // Send a message — the writer should hit a broken-pipe error
        tx.send(SendTask {
            msg_type: MessageType::Message,
            header: Bytes::from_static(b"hdr"),
            payload: Bytes::from_static(b"pay"),
            on_error: Arc::new(NullErrorHandler),
        })
        .unwrap();
        // Wait for writer task to finish
        let _ = writer.await;
        // The writer should have removed the stale entry from the map
        assert!(
            !connections.contains_key(&iid),
            "writer task should clean up its DashMap entry on write error"
        );
    }

    #[tokio::test]
    async fn test_send_message_does_not_fail_on_stale_handle() {
        let (transport, _our_addr) = make_transport();
        // Start a listener that accepts connections (simulates a healthy peer)
        let peer_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();
        let peer = make_tcp_peer(peer_addr);
        let iid = peer.instance_id();
        transport.register(peer).unwrap();
        // Insert a stale handle
        insert_stale_handle(&transport, iid);
        // send_message should detect the stale handle and create a new one,
        // NOT immediately call on_error
        let error_handler = Arc::new(TrackingErrorHandler::new());
        transport.send_message(
            iid,
            b"test-header".to_vec(),
            b"test-payload".to_vec(),
            MessageType::Message,
            error_handler.clone(),
        );
        // Accept the connection that the new writer task will establish
        let (mut stream, _) = peer_listener.accept().await.unwrap();
        // Read the framed message from the stream to confirm delivery
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 256];
        // Give the async writer a moment to flush the frame
        let n = tokio::time::timeout(Duration::from_secs(2), stream.read(&mut buf))
            .await
            .expect("timed out waiting for data")
            .expect("read error");
        assert!(n > 0, "expected data from the writer task");
        // No errors should have been reported
        assert_eq!(
            error_handler.error_count(),
            0,
            "send_message should retry on stale handle, not fail"
        );
        // The connections map should now contain a live handle
        let entry = transport.connections.get(&iid).unwrap();
        assert!(
            !entry.tx.is_disconnected(),
            "stale handle should have been replaced with a live one"
        );
    }

    #[tokio::test]
    async fn test_writer_task_drains_on_connect_failure() {
        // Use an address where nothing is listening so connect will fail.
        // Binding then immediately dropping gives us a port that was just released
        // and is therefore very likely closed.
        let tmp = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = tmp.local_addr().unwrap();
        drop(tmp);
        let iid = crate::InstanceId::new_v4();
        let (tx, rx) = flume::bounded::<SendTask>(8);
        let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
            Arc::new(DashMap::new());
        connections.insert(iid, ConnectionHandle { tx: tx.clone() });
        // Queue a message *before* the writer task even starts — this simulates
        // the race between create_connection returning and connect completing.
        let error_handler = Arc::new(TrackingErrorHandler::new());
        tx.send(SendTask {
            msg_type: MessageType::Message,
            header: Bytes::from_static(b"hdr"),
            payload: Bytes::from_static(b"pay"),
            on_error: error_handler.clone(),
        })
        .unwrap();
        let conns = Arc::clone(&connections);
        let cancel = CancellationToken::new();
        let writer = tokio::spawn(connection_writer_task(addr, iid, rx, conns, cancel));
        let _ = writer.await;
        assert_eq!(
            error_handler.error_count(),
            1,
            "queued message should have its on_error called when connect fails"
        );
        assert!(
            !connections.contains_key(&iid),
            "writer task should clean up its DashMap entry on connect failure"
        );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use bytes::Bytes;
use futures::future::BoxFuture;
use crate::{InstanceId, PeerInfo, TransportKey, WorkerAddress};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
/// Errors returned by individual [`Transport`] implementations.
///
/// Within the visible code this is the error type of [`Transport::register`].
/// The `Display` text of each variant is the string in its `#[error(...)]` attribute.
#[derive(thiserror::Error, Debug)]
pub enum TransportError {
/// The peer's [`WorkerAddress`] does not contain an entry for this transport.
#[error("No endpoint found for transport")]
NoEndpoint,
/// The endpoint string could not be parsed (malformed URL, invalid address).
#[error("Invalid endpoint format")]
InvalidEndpoint,
/// The target peer was never registered with this transport.
///
/// Carries the [`InstanceId`] of the peer, which is interpolated into the message.
#[error("Peer not registered: {0}")]
PeerNotRegistered(InstanceId),
/// The transport has not been started yet (no runtime handle).
#[error("Transport not started")]
NotStarted,
/// No responders available for the peer (e.g. NATS request with no subscriber).
#[error("No responders for peer")]
NoResponders,
}
/// Error type specific to health check operations
///
/// Returned by [`Transport::check_health`]. Unlike [`TransportError`] it also derives
/// `Clone`, `PartialEq` and `Eq`, so values can be copied and compared directly.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckError {
/// The peer was never registered with this transport.
#[error("Peer not registered with transport")]
PeerNotRegistered,
/// The transport has not been started yet.
#[error("Transport not started")]
TransportNotStarted,
/// The peer is registered but no connection has ever been established.
#[error("Connection never established to peer")]
NeverConnected,
/// An existing connection is unhealthy or the peer is unreachable.
#[error("Connection failed or peer unreachable")]
ConnectionFailed,
/// The health check exceeded the specified timeout.
#[error("Health check timed out")]
Timeout,
}
/// Shared shutdown coordinator for graceful multi-phase shutdown.
///
/// **Phases**:
/// 1. **Gate** — `begin_drain()` flips the draining flag; transports reject new inbound requests.
/// 2. **Drain** — `wait_for_drain()` asynchronously waits until all in-flight guards are dropped.
/// 3. **Teardown** — `teardown_token().cancel()` kills listeners and writer tasks.
///
/// Hot-path cost: a single `AtomicBool::load(Relaxed)` per frame to check `is_draining()`.
///
/// Cloning yields another handle to the same underlying state: the struct only holds
/// an `Arc`, so every clone observes the same flag, counter and teardown token.
#[derive(Clone)]
pub struct ShutdownState {
// Shared state; all clones (and all in-flight guards) point at the same allocation.
inner: Arc<ShutdownStateInner>,
}
/// State shared by every [`ShutdownState`] clone and every [`InFlightGuard`].
struct ShutdownStateInner {
/// Set to `true` by `begin_drain()`; read by `is_draining()`.
draining: AtomicBool,
/// Number of live in-flight guards: incremented in `acquire()`, decremented when a guard drops.
in_flight: AtomicUsize,
/// Signalled via `notify_waiters()` when a guard drop brings `in_flight` to zero;
/// awaited by `wait_for_drain()`.
drain_complete: Notify,
/// Token handed out by `teardown_token()`.
teardown_token: CancellationToken,
}
impl ShutdownState {
    /// Create a new shutdown state. Not draining, zero in-flight.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(ShutdownStateInner {
                draining: AtomicBool::new(false),
                in_flight: AtomicUsize::new(0),
                drain_complete: Notify::new(),
                teardown_token: CancellationToken::new(),
            }),
        }
    }

    /// Returns `true` if drain has been initiated (Phase 1).
    ///
    /// Uses `Relaxed` ordering — safe for the hot-path gate check because
    /// the flag is monotonic (false → true, never reset).
    #[inline]
    pub fn is_draining(&self) -> bool {
        self.inner.draining.load(Ordering::Relaxed)
    }

    /// Begin Phase 1: flip the draining flag. Idempotent.
    pub fn begin_drain(&self) {
        self.inner.draining.store(true, Ordering::Release);
    }

    /// Acquire an in-flight guard. The guard increments the counter on creation
    /// and decrements it on drop. Use this to track requests that are being processed.
    ///
    /// Guards are still acquirable after `begin_drain()` — this is intentional
    /// so that already-accepted work can be tracked.
    pub fn acquire(&self) -> InFlightGuard {
        self.inner.in_flight.fetch_add(1, Ordering::AcqRel);
        InFlightGuard {
            inner: self.inner.clone(),
        }
    }

    /// Current number of in-flight requests. Primarily for testing/debugging.
    pub fn in_flight_count(&self) -> usize {
        self.inner.in_flight.load(Ordering::Acquire)
    }

    /// Wait until in-flight count reaches zero. Returns immediately if already zero.
    ///
    /// The count is re-checked after every wakeup, so a guard acquired while
    /// waiting keeps this future pending until the count next reaches zero.
    pub async fn wait_for_drain(&self) {
        loop {
            // Create the `Notified` future BEFORE reading the counter.
            // `notify_waiters()` stores no permit, so a future created after the
            // last guard dropped would never be woken. A `Notified` future is
            // guaranteed to observe any `notify_waiters()` call made after its
            // creation, even if it has not been polled yet, which closes the
            // window between the check below and the `.await`.
            let notified = self.inner.drain_complete.notified();
            if self.inner.in_flight.load(Ordering::Acquire) == 0 {
                return;
            }
            notified.await;
        }
    }

    /// Get the Phase 3 teardown token. Cancel this to kill listeners/writers.
    pub fn teardown_token(&self) -> &CancellationToken {
        &self.inner.teardown_token
    }
}
impl Default for ShutdownState {
fn default() -> Self {
Self::new()
}
}
/// RAII guard that decrements the in-flight counter on drop.
///
/// Created by `ShutdownState::acquire()`. The guard must be kept alive for as
/// long as the tracked work is running: discarding it immediately releases the
/// slot again, which is why the type is marked `#[must_use]`.
#[must_use = "dropping the guard immediately releases the in-flight slot; bind it to a variable"]
pub struct InFlightGuard {
    /// Shared counter/notifier this guard reports back to when dropped.
    inner: Arc<ShutdownStateInner>,
}
impl InFlightGuard {
    /// Consume the guard, releasing its in-flight slot right away.
    ///
    /// This is purely a readability aid: it does exactly what dropping the
    /// guard does, because the decrement lives in the `Drop` implementation.
    pub fn complete(self) {
        drop(self);
    }
}
impl Drop for InFlightGuard {
    /// Release this guard's slot and wake drain waiters if it was the last one.
    fn drop(&mut self) {
        let shared = &self.inner;
        // `fetch_sub` yields the value *before* the decrement, so a previous
        // value of 1 means this guard just brought the counter to zero.
        let was_last = shared.in_flight.fetch_sub(1, Ordering::AcqRel) == 1;
        if was_last {
            shared.drain_complete.notify_waiters();
        }
    }
}
/// Policy for how long to wait during the drain phase.
///
/// NOTE(review): the code that consumes this policy is not visible in this file;
/// presumably the shutdown driver wraps `ShutdownState::wait_for_drain()` in a
/// bounded wait for `Timeout` — confirm against the caller.
#[derive(Debug, Clone)]
pub enum ShutdownPolicy {
/// Wait indefinitely for all in-flight requests to complete.
WaitForever,
/// Wait up to the given duration, then force teardown.
Timeout(Duration),
}
/// Abstraction over a single message transport (TCP, HTTP, NATS, gRPC, UCX).
///
/// Implementations handle peer registration, message sending, listener lifecycle,
/// health checking, and graceful shutdown. The trait is object-safe so transports
/// can be stored as `Arc<dyn Transport>`.
pub trait Transport: Send + Sync {
    /// Unique key identifying this transport (e.g. `"tcp"`, `"grpc"`).
    fn key(&self) -> TransportKey;

    /// The [`WorkerAddress`] fragment advertised by this transport.
    fn address(&self) -> WorkerAddress;

    /// Register a remote peer, extracting its endpoint from [`PeerInfo`].
    ///
    /// # Errors
    /// Returns a [`TransportError`] when the peer cannot be registered.
    fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError>;

    /// Sends an active message to the remote instance
    ///
    /// The call itself returns nothing; delivery failures are reported through
    /// `on_error` (see [`TransportErrorHandler`]).
    ///
    /// * `instance_id` — the target peer.
    /// * `header` / `payload` — the two byte buffers that make up the message.
    /// * `message_type` — discriminator used to route the frame to a stream.
    /// * `on_error` — callback invoked if the transport fails to deliver the message.
    fn send_message(
        &self,
        instance_id: InstanceId,
        header: Vec<u8>,
        payload: Vec<u8>,
        message_type: MessageType,
        on_error: Arc<dyn TransportErrorHandler>,
    );

    /// Start the transport (bind listener, spawn tasks) for the given instance.
    ///
    /// * `instance_id` — identity of the local instance this transport serves.
    /// * `channels` — sender-side adapter used to forward inbound frames.
    /// * `rt` — Tokio runtime handle handed to the transport.
    fn start(
        &self,
        instance_id: InstanceId,
        channels: TransportAdapter,
        rt: tokio::runtime::Handle,
    ) -> BoxFuture<'_, anyhow::Result<()>>;

    /// Tear down the transport, cancelling all tasks and closing connections.
    fn shutdown(&self);

    /// Begin draining: reject new inbound requests while allowing responses.
    ///
    /// Default implementation is a no-op. Transports that need per-frame
    /// gating (e.g., unsubscribing from NATS subjects) should override this.
    fn begin_drain(&self) {}

    /// Check if a registered peer is reachable and healthy
    ///
    /// Returns Ok(()) if peer responds to health check within timeout.
    /// Different transports implement this differently:
    /// - NATS: request/reply to health subject
    /// - TCP: check existing connection or attempt new connection
    /// - HTTP: HEAD request to health endpoint
    /// - UCX: endpoint status check
    ///
    /// The return type is spelled with [`BoxFuture`], the same alias `start` uses;
    /// it is exactly `Pin<Box<dyn Future<Output = _> + Send + '_>>`, so existing
    /// implementations that write out the long form remain valid.
    ///
    /// # Errors
    /// - `PeerNotRegistered`: Peer was never registered with this transport
    /// - `TransportNotStarted`: Transport hasn't been started yet
    /// - `NeverConnected`: Peer is registered but no connection has been established
    /// - `ConnectionFailed`: Connection exists/existed but is currently unhealthy or unreachable
    /// - `Timeout`: Health check took longer than the specified timeout
    fn check_health(
        &self,
        instance_id: InstanceId,
        timeout: Duration,
    ) -> BoxFuture<'_, Result<(), HealthCheckError>>;
}
/// Callback trait invoked when a transport fails to deliver a message.
///
/// The original `header` and `payload` are returned so higher layers can
/// retry or log the failure.
///
/// The `Send + Sync` bound allows handlers to be shared as
/// `Arc<dyn TransportErrorHandler>`, which is how `Transport::send_message` accepts them.
pub trait TransportErrorHandler: Send + Sync {
/// Called when message delivery fails. Receives the original data and error description.
///
/// * `header` / `payload` — the contents of the message that could not be delivered.
/// * `error` — human-readable description of the failure.
fn on_error(&self, header: Bytes, payload: Bytes, error: String);
}
/// Message type discriminator for routing frames to appropriate streams
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so each variant has a
/// fixed single-byte representation (see [`MessageType::as_u8`] and
/// [`MessageType::from_u8`]).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    #[allow(missing_docs)]
    Message = 0,
    #[allow(missing_docs)]
    Response = 1,
    #[allow(missing_docs)]
    Ack = 2,
    #[allow(missing_docs)]
    Event = 3,
    /// Sent back to a peer when we are draining and cannot accept new messages.
    /// The original request header is echoed back for correlation.
    ShuttingDown = 4,
}

impl MessageType {
    /// Try to convert a u8 to a MessageType
    ///
    /// Returns `None` for any byte that is not one of the declared discriminants.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Indexed by discriminant: entry `i` is the variant whose value is `i`.
        const BY_DISCRIMINANT: [MessageType; 5] = [
            MessageType::Message,
            MessageType::Response,
            MessageType::Ack,
            MessageType::Event,
            MessageType::ShuttingDown,
        ];
        BY_DISCRIMINANT.get(usize::from(value)).copied()
    }

    /// Convert MessageType to u8
    pub fn as_u8(self) -> u8 {
        // Lossless: the enum is `#[repr(u8)]`.
        let raw = self as u8;
        raw
    }
}
/// Sender-side handle given to transports for routing inbound frames.
///
/// Each transport receives a clone of this adapter during [`Transport::start`]
/// and uses it to forward decoded `(header, payload)` pairs to the appropriate
/// stream based on [`MessageType`].
///
/// The matching receivers live in [`DataStreams`]; [`make_channels`] creates both
/// halves together.
#[derive(Clone)]
pub struct TransportAdapter {
/// Channel for inbound [`MessageType::Message`] frames.
pub message_stream: flume::Sender<(Bytes, Bytes)>,
/// Channel for inbound [`MessageType::Response`] and [`MessageType::ShuttingDown`] frames.
pub response_stream: flume::Sender<(Bytes, Bytes)>,
/// Channel for inbound [`MessageType::Ack`] and [`MessageType::Event`] frames.
pub event_stream: flume::Sender<(Bytes, Bytes)>,
/// Shared shutdown coordinator for drain-aware routing.
pub shutdown_state: ShutdownState,
}
/// Receiver-side handle for consuming inbound frames from all transports.
///
/// Returned by [`make_channels`] alongside the corresponding [`TransportAdapter`].
/// Higher layers pull `(header, payload)` pairs from these channels.
///
/// Use `recv_message_tracked` to receive a message together with an in-flight
/// guard, or `into_parts` to take ownership of the raw receivers.
pub struct DataStreams {
/// Receiver for inbound message frames.
pub message_stream: flume::Receiver<(Bytes, Bytes)>,
/// Receiver for inbound response and shutting-down frames.
pub response_stream: flume::Receiver<(Bytes, Bytes)>,
/// Receiver for inbound ack and event frames.
pub event_stream: flume::Receiver<(Bytes, Bytes)>,
/// Shared shutdown coordinator.
pub shutdown_state: ShutdownState,
}
/// The three raw receivers in `(message, response, event)` order, as returned by
/// `DataStreams::into_parts`. Each carries `(header, payload)` pairs.
type DataStreamTuple = (
flume::Receiver<(Bytes, Bytes)>,
flume::Receiver<(Bytes, Bytes)>,
flume::Receiver<(Bytes, Bytes)>,
);
impl DataStreams {
    /// Destructure into the three raw receivers `(message, response, event)`.
    ///
    /// The shared `shutdown_state` handle is not part of the tuple and is
    /// dropped along with `self`.
    pub fn into_parts(self) -> DataStreamTuple {
        let Self {
            message_stream,
            response_stream,
            event_stream,
            ..
        } = self;
        (message_stream, response_stream, event_stream)
    }

    /// Receive a message with an in-flight guard for drain tracking.
    ///
    /// Returns `(header, payload, guard)`. The guard keeps the in-flight counter
    /// incremented until it is dropped or `complete()` is called. The guard is
    /// acquired only after a message has actually been received.
    ///
    /// # Errors
    /// Propagates the `flume::RecvError` from the message channel.
    pub async fn recv_message_tracked(
        &self,
    ) -> Result<(Bytes, Bytes, InFlightGuard), flume::RecvError> {
        match self.message_stream.recv_async().await {
            Ok((header, payload)) => {
                let guard = self.shutdown_state.acquire();
                Ok((header, payload, guard))
            }
            Err(recv_err) => Err(recv_err),
        }
    }
}
/// Create a matched pair of [`TransportAdapter`] (sender) and [`DataStreams`] (receiver).
///
/// Three unbounded flume channels are created (message, response, event); the
/// adapter gets the sending halves and the streams get the receiving halves.
/// Both sides share the same [`ShutdownState`] so drain coordination is automatic.
pub fn make_channels() -> (TransportAdapter, DataStreams) {
    let shutdown_state = ShutdownState::new();

    let (message_stream_tx, message_stream_rx) = flume::unbounded();
    let (response_stream_tx, response_stream_rx) = flume::unbounded();
    let (event_stream_tx, event_stream_rx) = flume::unbounded();

    let adapter = TransportAdapter {
        message_stream: message_stream_tx,
        response_stream: response_stream_tx,
        event_stream: event_stream_tx,
        // Clone for the sender side; the receiver side takes the original below.
        shutdown_state: shutdown_state.clone(),
    };
    let streams = DataStreams {
        message_stream: message_stream_rx,
        response_stream: response_stream_rx,
        event_stream: event_stream_rx,
        shutdown_state,
    };

    (adapter, streams)
}
#[cfg(test)]
mod tests {
    //! Tests for `ShutdownState`, `InFlightGuard`, `MessageType` conversions and
    //! the channel pair produced by `make_channels`.

    use super::*;
    use tokio::time::{sleep, timeout};

    /// A fresh state is not draining and tracks no work.
    #[test]
    fn test_shutdown_state_initial() {
        let shutdown = ShutdownState::new();
        assert!(!shutdown.is_draining());
        assert_eq!(shutdown.in_flight_count(), 0);
    }

    #[test]
    fn test_begin_drain_flips_flag() {
        let shutdown = ShutdownState::new();
        shutdown.begin_drain();
        assert!(shutdown.is_draining());
    }

    /// Calling `begin_drain` twice leaves the flag set.
    #[test]
    fn test_begin_drain_idempotent() {
        let shutdown = ShutdownState::new();
        for _ in 0..2 {
            shutdown.begin_drain();
        }
        assert!(shutdown.is_draining());
    }

    #[test]
    fn test_acquire_increments_inflight() {
        let shutdown = ShutdownState::new();
        let _first = shutdown.acquire();
        assert_eq!(shutdown.in_flight_count(), 1);
        let _second = shutdown.acquire();
        assert_eq!(shutdown.in_flight_count(), 2);
    }

    #[test]
    fn test_guard_drop_decrements_inflight() {
        let shutdown = ShutdownState::new();
        let guard = shutdown.acquire();
        assert_eq!(shutdown.in_flight_count(), 1);
        drop(guard);
        assert_eq!(shutdown.in_flight_count(), 0);
    }

    #[test]
    fn test_guard_complete_decrements() {
        let shutdown = ShutdownState::new();
        let guard = shutdown.acquire();
        assert_eq!(shutdown.in_flight_count(), 1);
        guard.complete();
        assert_eq!(shutdown.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn test_wait_for_drain_immediate() {
        let shutdown = ShutdownState::new();
        // Nothing is in flight, so the wait must resolve without blocking.
        let waited = timeout(Duration::from_millis(100), shutdown.wait_for_drain()).await;
        waited.expect("wait_for_drain should complete immediately when in_flight is 0");
    }

    #[tokio::test]
    async fn test_wait_for_drain_blocks_then_completes() {
        let shutdown = ShutdownState::new();
        let guard = shutdown.acquire();

        let waiter = {
            let shutdown = shutdown.clone();
            tokio::spawn(async move {
                shutdown.wait_for_drain().await;
            })
        };

        // Let the spawned waiter run and park on the notification.
        sleep(Duration::from_millis(50)).await;
        assert!(!waiter.is_finished());

        // Releasing the only guard must wake the waiter.
        drop(guard);
        timeout(Duration::from_millis(100), waiter)
            .await
            .expect("should complete after guard drop")
            .unwrap();
    }

    #[tokio::test]
    async fn test_multiple_guards_concurrent() {
        let shutdown = ShutdownState::new();
        let guards: Vec<_> = std::iter::repeat_with(|| shutdown.acquire())
            .take(10)
            .collect();
        assert_eq!(shutdown.in_flight_count(), 10);

        let waiter = {
            let shutdown = shutdown.clone();
            tokio::spawn(async move {
                shutdown.wait_for_drain().await;
            })
        };

        // Release every guard at once.
        drop(guards);
        timeout(Duration::from_millis(100), waiter)
            .await
            .expect("should complete after all guards drop")
            .unwrap();
        assert_eq!(shutdown.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn test_drain_with_zero_inflight() {
        let shutdown = ShutdownState::new();
        shutdown.begin_drain();
        // Draining with no tracked work must not block.
        let waited = timeout(Duration::from_millis(100), shutdown.wait_for_drain()).await;
        waited.expect("should complete immediately with zero in-flight");
    }

    #[test]
    fn test_acquire_works_after_drain() {
        let shutdown = ShutdownState::new();
        shutdown.begin_drain();
        let _guard = shutdown.acquire();
        assert_eq!(shutdown.in_flight_count(), 1);
    }

    #[test]
    fn test_guard_drop_during_panic() {
        let shutdown = ShutdownState::new();
        let panicking_work = std::panic::AssertUnwindSafe(|| {
            let _guard = shutdown.acquire();
            assert_eq!(shutdown.in_flight_count(), 1);
            panic!("intentional panic");
        });
        let outcome = std::panic::catch_unwind(panicking_work);
        assert!(outcome.is_err());
        // Unwinding runs the guard's destructor, so the slot is released.
        assert_eq!(shutdown.in_flight_count(), 0);
    }

    #[test]
    fn test_shutting_down_from_u8() {
        let decoded = MessageType::from_u8(4);
        assert_eq!(decoded, Some(MessageType::ShuttingDown));
    }

    #[test]
    fn test_shutting_down_as_u8() {
        let encoded = MessageType::ShuttingDown.as_u8();
        assert_eq!(encoded, 4);
    }

    #[test]
    fn test_unknown_message_type_still_none() {
        for unknown in [5u8, 255u8] {
            assert_eq!(MessageType::from_u8(unknown), None);
        }
    }

    #[test]
    fn test_make_channels_includes_shutdown_state() {
        let (adapter, streams) = make_channels();
        // Neither side starts out draining.
        assert!(!adapter.shutdown_state.is_draining());
        assert!(!streams.shutdown_state.is_draining());
        // The state is shared, so a change made via the adapter is seen by the streams.
        adapter.shutdown_state.begin_drain();
        assert!(streams.shutdown_state.is_draining());
    }

    #[tokio::test]
    async fn test_recv_message_tracked_returns_guard() {
        let (adapter, streams) = make_channels();

        // Push one frame through the sender side.
        let frame = (
            bytes::Bytes::from_static(b"hdr"),
            bytes::Bytes::from_static(b"pay"),
        );
        adapter.message_stream.send_async(frame).await.unwrap();

        // The tracked receive hands back the frame plus a live guard.
        let (header, payload, guard) = streams.recv_message_tracked().await.unwrap();
        assert_eq!(&header[..], b"hdr");
        assert_eq!(&payload[..], b"pay");
        assert_eq!(streams.shutdown_state.in_flight_count(), 1);

        // Releasing the guard returns the counter to zero.
        drop(guard);
        assert_eq!(streams.shutdown_state.in_flight_count(), 0);
    }

    #[test]
    fn test_shutdown_state_clone_shares_inner() {
        let original = ShutdownState::new();
        let cloned = original.clone();
        original.begin_drain();
        assert!(cloned.is_draining());
        let _guard = original.acquire();
        assert_eq!(cloned.in_flight_count(), 1);
    }

    #[test]
    fn test_teardown_token() {
        let shutdown = ShutdownState::new();
        let token = shutdown.teardown_token();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(shutdown.teardown_token().is_cancelled());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Common test utilities for transport integration tests
//!
//! This module provides a transport-agnostic test infrastructure that can be reused
//! across different transport implementations (TCP, RDMA, UDP, UDS, etc.).
#![allow(dead_code)]
// #[cfg(feature = "grpc")]
// use velo_transports::grpc::{GrpcTransport, GrpcTransportBuilder};
// #[cfg(feature = "http")]
// use velo_transports::http::{HttpTransport, HttpTransportBuilder};
// #[cfg(feature = "nats")]
// use velo_transports::nats::{NatsTransport, NatsTransportBuilder};
// #[cfg(feature = "ucx")]
// use velo_transports::ucx::{UcxTransport, UcxTransportBuilder};
use bytes::Bytes;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::timeout;
use velo_transports::{
DataStreams, InstanceId, MessageType, PeerInfo, Transport, TransportErrorHandler,
tcp::{TcpTransport, TcpTransportBuilder},
};
use std::sync::Once;
use tracing_subscriber::FmtSubscriber;
/// Ensures the tracing subscriber is installed at most once per process.
#[allow(dead_code)]
static INIT: Once = Once::new();

/// Install a formatting tracing subscriber with the `"trace"` filter.
///
/// Safe to call from many tests: only the first call does any work, and a
/// failure to install (for example because a global subscriber already exists)
/// is ignored.
#[allow(dead_code)]
pub fn init_tracing() {
    INIT.call_once(|| {
        let builder = FmtSubscriber::builder().with_env_filter("trace"); // or "info"
        let _ = builder.try_init();
    });
}
pub mod scenarios;
/// Test error handler that tracks errors for verification
///
/// Every reported error is appended to a list shared by all clones of the
/// handler, so a test can keep one clone and hand another to the transport.
/// `Default` is derived and yields an empty handler.
#[derive(Clone, Default)]
pub struct TestErrorHandler {
    /// Recorded `(header, payload, error message)` triples, in arrival order.
    errors: Arc<Mutex<Vec<(Bytes, Bytes, String)>>>,
}
impl TestErrorHandler {
    /// Create a handler with an empty error log.
    pub fn new() -> Self {
        let errors = Arc::new(Mutex::new(Vec::new()));
        Self { errors }
    }

    /// Snapshot of every error recorded so far.
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn get_errors(&self) -> Vec<(Bytes, Bytes, String)> {
        let recorded = self.errors.lock().unwrap();
        recorded.clone()
    }

    /// Number of errors recorded so far.
    pub fn error_count(&self) -> usize {
        let recorded = self.errors.lock().unwrap();
        recorded.len()
    }

    /// Forget all recorded errors.
    pub fn clear(&self) {
        let mut recorded = self.errors.lock().unwrap();
        recorded.clear();
    }
}
impl TransportErrorHandler for TestErrorHandler {
    /// Record the failed message and its error description for later inspection.
    fn on_error(&self, header: Bytes, payload: Bytes, error: String) {
        let mut recorded = self.errors.lock().unwrap();
        recorded.push((header, payload, error));
    }
}
/// Handle to a transport instance with its streams for testing
///
/// This is a generic test handle that works with any transport implementation.
/// Use `TestTransportHandle::with_factory()` to create instances with custom transports,
/// or use convenience methods like `TestTransportHandle::new()` for TCP transport.
pub struct TestTransportHandle<T: Transport> {
/// The transport under test; `with_factory()` has already started it.
pub transport: T,
/// Receiver side of the channels whose sender side was passed to the transport at start.
pub streams: DataStreams,
/// Randomly generated identity (`InstanceId::new_v4()`) this transport was started with.
pub instance_id: InstanceId,
/// Error handler passed along with every message sent through `send()`.
pub error_handler: Arc<TestErrorHandler>,
// Runtime handle captured in `with_factory()`; not read elsewhere in the visible code.
runtime: tokio::runtime::Handle,
}
impl<T: Transport> TestTransportHandle<T> {
/// Create a new test transport using a factory function
///
/// This is the generic constructor that works with any transport implementation.
/// The factory function should create and return a transport instance.
///
/// # Example
/// ```ignore
/// let handle = TestTransportHandle::with_factory(|| {
/// MyTransportBuilder::new().build()
/// }).await?;
/// ```
pub async fn with_factory<F>(factory: F) -> anyhow::Result<Self>
where
F: FnOnce() -> anyhow::Result<T>,
{
let transport = factory()?;
let instance_id = InstanceId::new_v4();
let error_handler = Arc::new(TestErrorHandler::new());
// Create channels for this transport
let (adapter, streams) = velo_transports::make_channels();
// Get runtime handle
let runtime = tokio::runtime::Handle::current();
// Start the transport
transport
.start(instance_id, adapter, runtime.clone())
.await?;
// Give the listener a moment to bind and start accepting connections
tokio::time::sleep(Duration::from_millis(50)).await;
Ok(Self {
transport,
streams,
instance_id,
error_handler,
runtime,
})
}
/// Register another transport as a peer
pub fn register_peer<U: Transport>(
&self,
other: &TestTransportHandle<U>,
) -> anyhow::Result<()> {
let peer_info = PeerInfo::new(other.instance_id, other.transport.address());
self.transport
.register(peer_info)
.map_err(|e| anyhow::anyhow!("Failed to register peer: {:?}", e))?;
Ok(())
}
/// Send a message to a peer
pub fn send(
&self,
target: InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
msg_type: MessageType,
) {
self.transport.send_message(
target,
header,
payload,
msg_type,
self.error_handler.clone(),
);
}
/// Receive a message with timeout
pub async fn recv_message(&self, timeout_duration: Duration) -> anyhow::Result<(Bytes, Bytes)> {
timeout(timeout_duration, self.streams.message_stream.recv_async())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for message"))?
.map_err(|e| anyhow::anyhow!("Channel error: {}", e))
}
/// Receive a response with timeout
pub async fn recv_response(
&self,
timeout_duration: Duration,
) -> anyhow::Result<(Bytes, Bytes)> {
timeout(timeout_duration, self.streams.response_stream.recv_async())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for response"))?
.map_err(|e| anyhow::anyhow!("Channel error: {}", e))
}
/// Receive an event with timeout
pub async fn recv_event(&self, timeout_duration: Duration) -> anyhow::Result<(Bytes, Bytes)> {
timeout(timeout_duration, self.streams.event_stream.recv_async())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for event"))?
.map_err(|e| anyhow::anyhow!("Channel error: {}", e))
}
/// Collect multiple messages with timeout
pub async fn collect_messages(
&self,
count: usize,
timeout_duration: Duration,
) -> anyhow::Result<Vec<(Bytes, Bytes)>> {
let mut messages = Vec::new();
for _ in 0..count {
messages.push(self.recv_message(timeout_duration).await?);
}
Ok(messages)
}
/// Collect multiple messages with timeout, sorted by header for order-independent comparison
///
/// This is useful for testing transports that don't guarantee delivery order (e.g., HTTP).
/// Messages are sorted by header bytes to enable deterministic comparison regardless of
/// delivery order.
pub async fn collect_messages_unordered(
&self,
count: usize,
timeout_duration: Duration,
) -> anyhow::Result<Vec<(Bytes, Bytes)>> {
let mut messages = self.collect_messages(count, timeout_duration).await?;
messages.sort_by(|a, b| a.0.cmp(&b.0));
Ok(messages)
}
/// Collect multiple responses with timeout
pub async fn collect_responses(
&self,
count: usize,
timeout_duration: Duration,
) -> anyhow::Result<Vec<(Bytes, Bytes)>> {
let mut responses = Vec::new();
for _ in 0..count {
responses.push(self.recv_response(timeout_duration).await?);
}
Ok(responses)
}
/// Shutdown the transport
///
/// Consumes the handle and forwards to the wrapped transport's `shutdown`, so the
/// handle cannot be used afterwards.
pub fn shutdown(self) {
self.transport.shutdown();
}
}
// TCP-specific convenience constructors
impl TestTransportHandle<TcpTransport> {
    /// Build a TCP-backed test handle listening on the loopback interface.
    ///
    /// The factory binds a `std::net::TcpListener` to `127.0.0.1:0` (a random available
    /// port) and hands the already-bound listener to `TcpTransportBuilder::from_listener`.
    /// For other transport types, use `with_factory()`.
    ///
    /// # Errors
    /// Fails if the bind, the builder, or `with_factory` itself fails.
    pub async fn new_tcp() -> anyhow::Result<Self> {
        Self::with_factory(|| {
            let loopback = std::net::TcpListener::bind("127.0.0.1:0")?;
            let builder = TcpTransportBuilder::new().from_listener(loopback)?;
            builder.build()
        })
        .await
    }

    /// Backward-compatible alias for [`Self::new_tcp`].
    pub async fn new() -> anyhow::Result<Self> {
        let handle = Self::new_tcp().await?;
        Ok(handle)
    }
}
// // UCX-specific convenience constructors
// #[cfg(feature = "ucx")]
// impl TestTransportHandle<UcxTransport> {
// /// Create a new UCX transport
// ///
// /// This is a convenience method for creating UCX transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_ucx() -> anyhow::Result<Self> {
// Self::with_factory(|| UcxTransportBuilder::new().build()).await
// }
// }
// // HTTP-specific convenience constructors
// #[cfg(feature = "http")]
// impl TestTransportHandle<HttpTransport> {
// /// Create a new HTTP transport with OS-provided port
// ///
// /// This is a convenience method for creating HTTP transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_http() -> anyhow::Result<Self> {
// Self::with_factory(|| {
// // Use default builder which binds to 0.0.0.0:0 (OS-provided port)
// HttpTransportBuilder::new().build()
// })
// .await
// }
// }
// // NATS-specific convenience constructor
// #[cfg(feature = "nats")]
// impl TestTransportHandle<NatsTransport> {
// /// Create a new NATS transport
// ///
// /// This is a convenience method for creating NATS transports.
// /// For other transport types, use `with_factory()`.
// ///
// /// Note: NATS transport requires special handling because it needs the instance_id
// /// at construction time to set up subject subscriptions. We can't use the generic
// /// with_factory() because it creates the instance_id AFTER calling the factory.
// pub async fn new_nats() -> anyhow::Result<Self> {
// // Create instance_id
// let instance_id = InstanceId::new_v4();
// let error_handler = Arc::new(TestErrorHandler::new());
// // Build transport
// let transport = NatsTransportBuilder::new()
// .nats_url("nats://127.0.0.1:4222")
// .build()?;
// // Create channels for this transport
// let (adapter, streams) = velo_transports::make_channels();
// // Get runtime handle
// let runtime = tokio::runtime::Handle::current();
// // Start the transport
// transport
// .start(instance_id, adapter, runtime.clone())
// .await?;
// // Give NATS a moment to establish subscriptions
// tokio::time::sleep(Duration::from_millis(50)).await;
// Ok(Self {
// transport,
// streams,
// instance_id,
// error_handler,
// runtime,
// })
// }
// }
// // gRPC-specific convenience constructors
// #[cfg(feature = "grpc")]
// impl TestTransportHandle<GrpcTransport> {
// /// Create a new gRPC transport with OS-provided port
// ///
// /// This is a convenience method for creating gRPC transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_grpc() -> anyhow::Result<Self> {
// Self::with_factory(|| {
// // Use default builder which binds to 0.0.0.0:0 (OS-provided port)
// GrpcTransportBuilder::new().build()
// })
// .await
// }
// }
/// Multi-transport test cluster
///
/// A generic cluster that works with any transport implementation.
/// All transports in the cluster are registered with each other in a full mesh topology.
pub struct TestCluster<T: Transport> {
// Handles for the cluster members, stored in the order they were created;
// `get(index)` indexes directly into this vector.
transports: Vec<TestTransportHandle<T>>,
}
impl<T: Transport> TestCluster<T> {
    /// Build a cluster of `size` transports and connect them in a full mesh.
    ///
    /// `factory` is invoked once per member (via `TestTransportHandle::with_factory`).
    /// After all members exist, every member registers every *other* member as a peer,
    /// i.e. `size * (size - 1)` registrations in total.
    ///
    /// # Example
    /// ```ignore
    /// let cluster = TestCluster::with_factory(3, || {
    ///     MyTransportBuilder::new().build()
    /// }).await?;
    /// ```
    ///
    /// # Errors
    /// Returns the first error from creating a member or registering a peer.
    pub async fn with_factory<F>(size: usize, factory: F) -> anyhow::Result<Self>
    where
        F: Fn() -> anyhow::Result<T>,
    {
        let mut transports = Vec::new();
        for _ in 0..size {
            let member = TestTransportHandle::with_factory(&factory).await?;
            transports.push(member);
        }

        // Full mesh: each node learns about every node except itself.
        for (i, node) in transports.iter().enumerate() {
            for (j, peer) in transports.iter().enumerate() {
                if i == j {
                    continue;
                }
                node.register_peer(peer)?;
            }
        }

        Ok(Self { transports })
    }

    /// Borrow the member at `index` (creation order).
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    pub fn get(&self, index: usize) -> &TestTransportHandle<T> {
        let members = &self.transports;
        &members[index]
    }

    /// Borrow every member of the cluster as a slice, in creation order.
    pub fn all(&self) -> &[TestTransportHandle<T>] {
        self.transports.as_slice()
    }

    /// Consume the cluster and shut down each member in creation order.
    pub fn shutdown(self) {
        self.transports
            .into_iter()
            .for_each(|member| member.shutdown());
    }
}
// TCP-specific convenience constructor
impl TestCluster<TcpTransport> {
    /// Build a full-mesh TCP cluster with `size` members.
    ///
    /// Each member gets its own `std::net::TcpListener` bound to `127.0.0.1:0`, which is
    /// handed to `TcpTransportBuilder::from_listener`. For other transport types, use
    /// `with_factory()`.
    ///
    /// # Errors
    /// Propagates bind, builder, and `with_factory` failures.
    pub async fn new(size: usize) -> anyhow::Result<Self> {
        Self::with_factory(size, || {
            let loopback = std::net::TcpListener::bind("127.0.0.1:0")?;
            let builder = TcpTransportBuilder::new().from_listener(loopback)?;
            builder.build()
        })
        .await
    }
}
// UCX-specific convenience constructor
#[cfg(feature = "ucx")]
impl TestCluster<UcxTransport> {
    /// Build a full-mesh UCX cluster with `size` members.
    ///
    /// Every member is produced by a default `UcxTransportBuilder`. Only compiled when
    /// the `ucx` feature is enabled. For other transport types, use `with_factory()`.
    pub async fn new_ucx(size: usize) -> anyhow::Result<Self> {
        let build_ucx = || UcxTransportBuilder::new().build();
        Self::with_factory(size, build_ucx).await
    }
}
// // HTTP-specific convenience constructor
// #[cfg(feature = "http")]
// impl TestCluster<HttpTransport> {
// /// Create a new HTTP test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating HTTP clusters.
// /// For other transport types, use `with_factory()`.
// pub async fn new_http(size: usize) -> anyhow::Result<Self> {
// Self::with_factory(size, || {
// // Use default builder which binds to OS-provided ports
// HttpTransportBuilder::new().build()
// })
// .await
// }
// }
// // NATS-specific convenience constructor
// #[cfg(feature = "nats")]
// impl TestCluster<NatsTransport> {
// /// Create a new NATS test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating NATS clusters.
// /// For other transport types, use `with_factory()`.
// ///
// /// Note: NATS transport requires special handling because it needs the instance_id
// /// at construction time. We can't use the generic with_factory() which creates
// /// instance_id after calling the factory function.
// pub async fn new_nats(size: usize) -> anyhow::Result<Self> {
// let mut transports = Vec::new();
// for _ in 0..size {
// transports.push(TestTransportHandle::new_nats().await?);
// }
// // Register all peers with each other (full mesh)
// for i in 0..transports.len() {
// for j in 0..transports.len() {
// if i != j {
// transports[i].register_peer(&transports[j])?;
// }
// }
// }
// Ok(Self { transports })
// }
// }
// // gRPC-specific convenience constructor
// #[cfg(feature = "grpc")]
// impl TestCluster<GrpcTransport> {
// /// Create a new gRPC test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating gRPC clusters.
// /// For other transport types, use `with_factory()`.
// pub async fn new_grpc(size: usize) -> anyhow::Result<Self> {
// Self::with_factory(size, || {
// // Use default builder which binds to OS-provided ports
// GrpcTransportBuilder::new().build()
// })
// .await
// }
// }
// Helper utilities
/// Get a random available port
///
/// Binds a temporary listener to `127.0.0.1:0`, reads back the port the OS assigned,
/// and returns it. The listener is dropped before this function returns, so the port
/// is only *likely* to be free afterwards: another process (or a parallel test) may
/// claim it before the caller binds. Where possible, prefer binding a listener and
/// passing it on (as the TCP constructors in this module do) over reserving a port
/// number.
///
/// # Panics
/// Panics if the loopback bind fails or the bound address cannot be read.
pub fn get_random_port() -> u16 {
    use std::net::TcpListener;

    let listener = TcpListener::bind("127.0.0.1:0")
        .expect("failed to bind an ephemeral port on 127.0.0.1");
    listener
        .local_addr()
        .expect("failed to read the local address of the bound listener")
        .port()
}
/// Create `size` bytes of predictable test data.
///
/// The bytes count up `0, 1, ..., 255` and then wrap back to `0`, so byte `i` is
/// always `i % 256`.
pub fn test_data(size: usize) -> Vec<u8> {
    (0..=u8::MAX).cycle().take(size).collect()
}
/// Create a `(header, payload)` pair whose contents are derived from `id`.
///
/// The header is the UTF-8 bytes of `header-<id>` and the payload the bytes of
/// `payload-<id>`, so distinct ids always give distinct, recognizable messages.
pub fn test_message(id: u32) -> (Vec<u8>, Vec<u8>) {
    (
        format!("header-{}", id).into_bytes(),
        format!("payload-{}", id).into_bytes(),
    )
}
/// Assert that a received `(header, payload)` pair matches the expected bytes.
///
/// # Panics
/// Panics with "Header mismatch" if the header differs, otherwise with
/// "Payload mismatch" if the payload differs.
pub fn assert_message_eq(
    received: (Bytes, Bytes),
    expected_header: &[u8],
    expected_payload: &[u8],
) {
    let (header, payload) = received;
    assert_eq!(header.as_ref(), expected_header, "Header mismatch");
    assert_eq!(payload.as_ref(), expected_payload, "Payload mismatch");
}
// Transport factory abstraction for parameterized tests
/// Transport factory trait for creating transports in parameterized tests
///
/// Implementors are used as a type parameter (`F: TransportFactory`) so that a scenario
/// can be written once and instantiated for each transport. Both constructors are
/// associated functions and take no `self`.
pub trait TransportFactory {
/// Concrete transport type produced by this factory.
type Transport: Transport;
/// Create a single standalone transport handle.
async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>>;
/// Create a cluster containing `size` transports.
async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>>;
}
/// Factory that produces TCP-backed test transports.
pub struct TcpFactory;

impl TransportFactory for TcpFactory {
    type Transport = TcpTransport;

    /// Delegates to `TestTransportHandle::new_tcp`.
    async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
        TestTransportHandle::<TcpTransport>::new_tcp().await
    }

    /// Delegates to the TCP-specific `TestCluster::new`.
    async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
        TestCluster::<TcpTransport>::new(size).await
    }
}
// /// UCX transport factory
// #[cfg(feature = "ucx")]
// pub struct UcxFactory;
// #[cfg(feature = "ucx")]
// impl TransportFactory for UcxFactory {
// type Transport = UcxTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_ucx().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_ucx(size).await
// }
// }
// /// HTTP transport factory
// #[cfg(feature = "http")]
// pub struct HttpFactory;
// #[cfg(feature = "http")]
// impl TransportFactory for HttpFactory {
// type Transport = HttpTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_http().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_http(size).await
// }
// }
// /// NATS transport factory
// #[cfg(feature = "nats")]
// pub struct NatsFactory;
// #[cfg(feature = "nats")]
// impl TransportFactory for NatsFactory {
// type Transport = NatsTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_nats().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_nats(size).await
// }
// }
// /// gRPC transport factory
// #[cfg(feature = "grpc")]
// pub struct GrpcFactory;
// #[cfg(feature = "grpc")]
// impl TransportFactory for GrpcFactory {
// type Transport = GrpcTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_grpc().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_grpc(size).await
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Generic test scenarios that work with any transport implementation
use super::*;
use std::time::Duration;
const TEST_TIMEOUT: Duration = Duration::from_secs(5);
/// One `Message` frame sent from A to B must arrive on B's message stream with the
/// same header and payload.
pub async fn single_message_round_trip<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// With both sides registered as each other's peer, a message sent A -> B and another
/// sent B -> A must each arrive intact at the opposite end.
pub async fn bidirectional_messaging<F: TransportFactory>() {
    let node_a = F::create().await.unwrap();
    let node_b = F::create().await.unwrap();
    node_a.register_peer(&node_b).unwrap();
    node_b.register_peer(&node_a).unwrap();

    // Both sends are issued before either side starts receiving.
    let (a_to_b_header, a_to_b_payload) = test_message(1);
    node_a.send(
        node_b.instance_id,
        a_to_b_header.clone(),
        a_to_b_payload.clone(),
        MessageType::Message,
    );
    let (b_to_a_header, b_to_a_payload) = test_message(2);
    node_b.send(
        node_a.instance_id,
        b_to_a_header.clone(),
        b_to_a_payload.clone(),
        MessageType::Message,
    );

    let at_b = node_b.recv_message(TEST_TIMEOUT).await.unwrap();
    let at_a = node_a.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(at_b, &a_to_b_header, &a_to_b_payload);
    assert_message_eq(at_a, &b_to_a_header, &b_to_a_payload);

    node_a.shutdown();
    node_b.shutdown();
}
/// Ten messages sent back-to-back from A to B must all arrive intact.
///
/// Delivery order is not asserted: received and expected messages are both sorted by
/// header before being compared pairwise.
pub async fn multiple_messages_same_connection<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    for id in 0..10 {
        let (header, payload) = test_message(id);
        sender.send(receiver.instance_id, header, payload, MessageType::Message);
    }

    let received = receiver
        .collect_messages_unordered(10, TEST_TIMEOUT)
        .await
        .unwrap();

    // Sort the expectation with the same header-only comparison the collector uses.
    let mut expected: Vec<_> = (0..10).map(test_message).collect();
    expected.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

    for (message, (want_header, want_payload)) in received.iter().zip(expected.iter()) {
        assert_message_eq(message.clone(), want_header, want_payload);
    }

    sender.shutdown();
    receiver.shutdown();
}
/// A frame sent with `MessageType::Response` must be delivered on B's response stream.
pub async fn response_message_type<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Response,
    );

    let delivered = receiver.recv_response(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A frame sent with `MessageType::Event` must be delivered on B's event stream.
pub async fn event_message_type<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Event,
    );

    let delivered = receiver.recv_event(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A frame sent with `MessageType::Ack` must be delivered on B's event stream
/// (acks share the event stream rather than having one of their own).
pub async fn ack_message_type<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Ack,
    );

    // Read from the event stream: that is where acks are routed.
    let delivered = receiver.recv_event(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A Message, a Response, and an Event sent back-to-back from A to B must each show up
/// on the stream that matches its type, with contents intact.
pub async fn mixed_message_types<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();
    let destination = receiver.instance_id;

    // Each frame carries a distinct id so a mis-routed frame would fail the comparison.
    let (message_header, message_payload) = test_message(1);
    sender.send(
        destination,
        message_header.clone(),
        message_payload.clone(),
        MessageType::Message,
    );
    let (response_header, response_payload) = test_message(2);
    sender.send(
        destination,
        response_header.clone(),
        response_payload.clone(),
        MessageType::Response,
    );
    let (event_header, event_payload) = test_message(3);
    sender.send(
        destination,
        event_header.clone(),
        event_payload.clone(),
        MessageType::Event,
    );

    let got_message = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    let got_response = receiver.recv_response(TEST_TIMEOUT).await.unwrap();
    let got_event = receiver.recv_event(TEST_TIMEOUT).await.unwrap();

    assert_message_eq(got_message, &message_header, &message_payload);
    assert_message_eq(got_response, &response_header, &response_payload);
    assert_message_eq(got_event, &event_header, &event_payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A message carrying a 1 MiB (`1024 * 1024` byte) payload must arrive byte-for-byte intact.
pub async fn large_payload<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let header = b"large-payload".to_vec();
    let one_mebibyte = test_data(1024 * 1024);
    sender.send(
        receiver.instance_id,
        header.clone(),
        one_mebibyte.clone(),
        MessageType::Message,
    );

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &one_mebibyte);

    sender.shutdown();
    receiver.shutdown();
}
/// A message whose header and payload are both empty must still be delivered, and must
/// arrive with an empty header and an empty payload.
pub async fn empty_header_and_payload<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let destination = receiver.instance_id;
    sender.send(destination, vec![], vec![], MessageType::Message);

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &[], &[]);

    sender.shutdown();
    receiver.shutdown();
}
/// In a three-node full-mesh cluster, every node sends one message to each of the other
/// two; every node must then receive exactly two messages.
///
/// Only the number of received messages is checked, not their contents.
pub async fn cluster_mesh_communication<F: TransportFactory>() {
    let cluster = F::create_cluster(3).await.unwrap();

    for src in 0..3 {
        for dst in (0..3).filter(|&dst| dst != src) {
            // The id encodes the (source, destination) pair, e.g. 0 -> 2 becomes 2, 1 -> 0 becomes 10.
            let (header, payload) = test_message((src * 10 + dst) as u32);
            let destination = cluster.get(dst).instance_id;
            cluster
                .get(src)
                .send(destination, header, payload, MessageType::Message);
        }
    }

    for node in 0..3 {
        let inbox = cluster
            .get(node)
            .collect_messages(2, TEST_TIMEOUT)
            .await
            .unwrap();
        assert_eq!(inbox.len(), 2);
    }

    cluster.shutdown();
}
/// Ten messages queued from A to B must all be received.
///
/// NOTE(review): despite the name, the ten sends below are issued sequentially from
/// this task. The spawned tasks only sleep for a microsecond and never touch either
/// transport, so this scenario does not exercise concurrent senders.
pub async fn concurrent_senders<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let target_id = receiver.instance_id;

    // Sequential sends; `send` is called without spawning.
    for id in 0..10 {
        let (header, payload) = test_message(id);
        sender.send(target_id, header, payload, MessageType::Message);
    }

    // Ten independent tasks that merely sleep; they are joined before receiving.
    let mut sleepers = vec![];
    for _ in 0..10 {
        sleepers.push(tokio::spawn(async {
            tokio::time::sleep(Duration::from_micros(1)).await;
        }));
    }
    for sleeper in sleepers {
        sleeper.await.unwrap();
    }

    let received = receiver
        .collect_messages(10, TEST_TIMEOUT)
        .await
        .unwrap();
    assert_eq!(received.len(), 10);

    sender.shutdown();
    receiver.shutdown();
}
/// Sending to a peer that was never registered must be reported to the sender's error
/// handler exactly once, carrying the original header and payload and a reason that
/// contains "peer not registered" (case-insensitive).
pub async fn send_to_unregistered_peer<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let stranger = F::create().await.unwrap();
    // Deliberately no `register_peer` call: the sender does not know the stranger.

    let (header, payload) = test_message(1);
    sender.send(
        stranger.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    // Fixed grace period for the failure to reach the error handler.
    tokio::time::sleep(Duration::from_millis(100)).await;

    assert!(sender.error_handler.error_count() > 0);
    let errors = sender.error_handler.get_errors();
    assert_eq!(errors.len(), 1);
    let recorded = &errors[0];
    assert_eq!(recorded.0, header.as_slice());
    assert_eq!(recorded.1, payload.as_slice());
    assert!(recorded.2.to_lowercase().contains("peer not registered"));

    sender.shutdown();
    stranger.shutdown();
}
/// Two messages sent one after the other from A to B must both arrive intact, and the
/// sender's error handler must record no errors.
///
/// The second send is issued only after the first message has been received, so it is
/// expected to travel over the connection the first one established.
pub async fn connection_reuse<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();
    let destination = receiver.instance_id;

    // First exchange: establishes the connection.
    let (first_header, first_payload) = test_message(1);
    sender.send(
        destination,
        first_header.clone(),
        first_payload.clone(),
        MessageType::Message,
    );
    let first = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(first, &first_header, &first_payload);

    // Second exchange: should reuse it.
    let (second_header, second_payload) = test_message(2);
    sender.send(
        destination,
        second_header.clone(),
        second_payload.clone(),
        MessageType::Message,
    );
    let second = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(second, &second_header, &second_payload);

    assert_eq!(sender.error_handler.error_count(), 0);

    sender.shutdown();
    receiver.shutdown();
}
/// After one successful A -> B exchange, shutting down both transports must complete
/// without panicking.
pub async fn graceful_shutdown<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    // Exchange one message so both sides are active before shutting down.
    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );
    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    // The scenario passes if neither call panics.
    sender.shutdown();
    receiver.shutdown();
}
/// One hundred messages sent back-to-back from A to B must all arrive intact.
///
/// Delivery order is not asserted: received and expected messages are both sorted by
/// header before being compared pairwise.
pub async fn high_throughput<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let num_messages = 100;
    for id in 0..num_messages {
        let (header, payload) = test_message(id);
        sender.send(receiver.instance_id, header, payload, MessageType::Message);
    }

    let received = receiver
        .collect_messages_unordered(num_messages as usize, TEST_TIMEOUT)
        .await
        .unwrap();
    assert_eq!(received.len(), num_messages as usize);

    // Sort the expectation with the same header-only comparison the collector uses.
    let mut expected: Vec<_> = (0..num_messages).map(test_message).collect();
    expected.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

    for (message, (want_header, want_payload)) in received.iter().zip(expected.iter()) {
        assert_message_eq(message.clone(), want_header, want_payload);
    }

    sender.shutdown();
    receiver.shutdown();
}
/// A 512 KiB payload must arrive intact with no errors recorded on the sender.
///
/// NOTE(review): the scenario only checks delivery and the error count; it does not
/// itself measure whether any copies were avoided.
pub async fn zero_copy_efficiency<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let header = b"zero-copy-test".to_vec();
    let half_mebibyte = test_data(512 * 1024);
    sender.send(
        receiver.instance_id,
        header.clone(),
        half_mebibyte.clone(),
        MessageType::Message,
    );

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &half_mebibyte);

    assert_eq!(sender.error_handler.error_count(), 0);

    sender.shutdown();
    receiver.shutdown();
}
// --- Drain / shutdown scenarios ---
/// Once B has begun draining, a `Message` sent from A to B must NOT appear on B's
/// message stream.
///
/// The check is a negative one: receiving from B's message stream has to time out
/// after 500 ms.
pub async fn drain_rejects_messages<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    sender.register_peer(&draining).unwrap();

    // Drain at both levels (transport and shutdown state), mirroring
    // VeloBackend::graceful_shutdown, then give it a moment to take effect.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (header, payload) = test_message(1);
    sender.send(draining.instance_id, header, payload, MessageType::Message);

    let outcome = tokio::time::timeout(
        Duration::from_millis(500),
        draining.streams.message_stream.recv_async(),
    )
    .await;
    assert!(
        outcome.is_err(),
        "Expected timeout — messages should be rejected during drain"
    );

    sender.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
/// Once B has begun draining, a `Response` sent from A to B must still be delivered on
/// B's response stream.
pub async fn drain_accepts_responses<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    sender.register_peer(&draining).unwrap();

    // Drain at both levels, then give it a moment to take effect.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (header, payload) = test_message(1);
    sender.send(
        draining.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Response,
    );

    let delivered = draining.recv_response(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
/// Once B has begun draining, an `Event` sent from A to B must still be delivered on
/// B's event stream.
pub async fn drain_accepts_events<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    sender.register_peer(&draining).unwrap();

    // Drain at both levels, then give it a moment to take effect.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (header, payload) = test_message(1);
    sender.send(
        draining.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Event,
    );

    let delivered = draining.recv_event(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
/// Once B has begun draining, a health check issued by A against B must still succeed.
///
/// A connection is established first (one message round trip) so the health check runs
/// against an already-connected peer.
pub async fn health_during_drain<F: TransportFactory>() {
    let checker = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    checker.register_peer(&draining).unwrap();

    // Warm-up exchange before the drain starts.
    let (header, payload) = test_message(1);
    checker.send(
        draining.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );
    let delivered = draining.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    // Drain at both levels, then give it a moment to take effect.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let health = checker
        .transport
        .check_health(draining.instance_id, Duration::from_secs(2))
        .await;
    assert!(
        health.is_ok(),
        "Health check should succeed during drain: {:?}",
        health.err()
    );

    checker.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for TCP transport
mod common;
use common::{TcpFactory, scenarios};
/// Runs the shared `single_message_round_trip` scenario against the TCP transport.
#[tokio::test]
async fn test_single_message_round_trip() {
scenarios::single_message_round_trip::<TcpFactory>().await;
}
/// Runs the shared `bidirectional_messaging` scenario against the TCP transport.
#[tokio::test]
async fn test_bidirectional_messaging() {
scenarios::bidirectional_messaging::<TcpFactory>().await;
}
/// Runs the shared `multiple_messages_same_connection` scenario against the TCP transport.
#[tokio::test]
async fn test_multiple_messages_same_connection() {
scenarios::multiple_messages_same_connection::<TcpFactory>().await;
}
/// Runs the shared `response_message_type` scenario against the TCP transport.
#[tokio::test]
async fn test_response_message_type() {
scenarios::response_message_type::<TcpFactory>().await;
}
/// Runs the shared `event_message_type` scenario against the TCP transport.
#[tokio::test]
async fn test_event_message_type() {
scenarios::event_message_type::<TcpFactory>().await;
}
/// Runs the shared `ack_message_type` scenario against the TCP transport.
#[tokio::test]
async fn test_ack_message_type() {
scenarios::ack_message_type::<TcpFactory>().await;
}
/// Runs the shared `mixed_message_types` scenario against the TCP transport.
#[tokio::test]
async fn test_mixed_message_types() {
scenarios::mixed_message_types::<TcpFactory>().await;
}
/// Runs the shared `large_payload` scenario against the TCP transport.
#[tokio::test]
async fn test_large_payload() {
scenarios::large_payload::<TcpFactory>().await;
}
/// Runs the shared `empty_header_and_payload` scenario against the TCP transport.
#[tokio::test]
async fn test_empty_header_and_payload() {
scenarios::empty_header_and_payload::<TcpFactory>().await;
}
/// Runs the shared `cluster_mesh_communication` scenario against the TCP transport.
#[tokio::test]
async fn test_cluster_mesh_communication() {
scenarios::cluster_mesh_communication::<TcpFactory>().await;
}
/// Runs the shared `concurrent_senders` scenario against the TCP transport.
#[tokio::test]
async fn test_concurrent_senders() {
scenarios::concurrent_senders::<TcpFactory>().await;
}
/// Runs the shared `send_to_unregistered_peer` scenario against the TCP transport.
#[tokio::test]
async fn test_send_to_unregistered_peer() {
scenarios::send_to_unregistered_peer::<TcpFactory>().await;
}
/// Runs the shared `connection_reuse` scenario against the TCP transport.
#[tokio::test]
async fn test_connection_reuse() {
scenarios::connection_reuse::<TcpFactory>().await;
}
/// Runs the shared `graceful_shutdown` scenario against the TCP transport.
#[tokio::test]
async fn test_graceful_shutdown() {
scenarios::graceful_shutdown::<TcpFactory>().await;
}
/// Runs the shared `high_throughput` scenario against the TCP transport.
#[tokio::test]
async fn test_high_throughput() {
scenarios::high_throughput::<TcpFactory>().await;
}
/// Runs the shared `zero_copy_efficiency` scenario against the TCP transport.
#[tokio::test]
async fn test_zero_copy_efficiency() {
scenarios::zero_copy_efficiency::<TcpFactory>().await;
}
/// Runs the shared `drain_rejects_messages` scenario against the TCP transport.
#[tokio::test]
async fn test_drain_rejects_messages() {
scenarios::drain_rejects_messages::<TcpFactory>().await;
}
/// Runs the shared `drain_accepts_responses` scenario against the TCP transport.
#[tokio::test]
async fn test_drain_accepts_responses() {
scenarios::drain_accepts_responses::<TcpFactory>().await;
}
/// Runs the shared `drain_accepts_events` scenario against the TCP transport.
#[tokio::test]
async fn test_drain_accepts_events() {
scenarios::drain_accepts_events::<TcpFactory>().await;
}
/// Runs the shared `health_during_drain` scenario against the TCP transport.
#[tokio::test]
async fn test_health_during_drain() {
scenarios::health_during_drain::<TcpFactory>().await;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for TCP graceful shutdown
//!
//! These tests verify the 3-phase shutdown behavior:
//! 1. Gate: new Message frames are rejected with ShuttingDown
//! 2. Drain: in-flight work completes, responses/events still flow
//! 3. Teardown: listener and writer tasks exit
mod common;
use bytes::Bytes;
use std::time::Duration;
use tokio::time::{sleep, timeout};
use velo_transports::tcp::TcpFrameCodec;
use velo_transports::{MessageType, Transport};
use common::TestTransportHandle;
/// Helper: open a raw TCP connection to `addr`, write a single frame of `msg_type`
/// carrying `header` and `payload`, and return the still-open stream.
///
/// The stream is returned so the caller can read a reply from it; dropping the return
/// value closes the connection.
///
/// # Panics
/// Panics if connecting or encoding the frame fails.
async fn connect_and_send_frame(
    addr: std::net::SocketAddr,
    msg_type: MessageType,
    header: &[u8],
    payload: &[u8],
) -> tokio::net::TcpStream {
    let mut socket = tokio::net::TcpStream::connect(addr).await.unwrap();
    let written = TcpFrameCodec::encode_frame(&mut socket, msg_type, header, payload).await;
    written.unwrap();
    socket
}
/// Helper: decode and return the next frame from a raw TCP stream.
///
/// The stream is wrapped in a fresh `Framed` decoder for the duration of the call.
///
/// # Panics
/// Panics if the framed stream yields no item (first `unwrap`) or yields an error
/// (second `unwrap`).
async fn read_one_frame(stream: &mut tokio::net::TcpStream) -> (MessageType, Bytes, Bytes) {
    use futures::StreamExt;
    use tokio_util::codec::Framed;

    let mut frames = Framed::new(stream, TcpFrameCodec::new());
    let first = frames.next().await;
    first.unwrap().unwrap()
}
/// Recover the socket address a `TcpTransport` is reachable on.
///
/// Looks up the transport's own entry (keyed by `transport.key()`) in its
/// `WorkerAddress`, interprets it as UTF-8, strips an optional `tcp://` scheme prefix,
/// and parses the remainder as a `SocketAddr`.
///
/// # Panics
/// Panics if the entry is missing, is not valid UTF-8, or does not parse as a socket
/// address.
fn get_bind_addr(
    handle: &TestTransportHandle<velo_transports::tcp::TcpTransport>,
) -> std::net::SocketAddr {
    let addr = handle.transport.address();
    let key = handle.transport.key();
    let endpoint = addr.get_entry(&key).unwrap().unwrap();

    let text = std::str::from_utf8(&endpoint).unwrap();
    // The scheme is optional: fall back to the raw text when it is absent.
    let host_and_port = text.strip_prefix("tcp://").unwrap_or(text);
    host_and_port.parse().unwrap()
}
// --- Test 18: Drain rejects Message frames ---
/// While draining, a raw `Message` frame must be answered with a `ShuttingDown` frame
/// that echoes the request header and carries an empty payload.
#[tokio::test]
async fn test_tcp_drain_rejects_messages() {
    let node = TestTransportHandle::new_tcp().await.unwrap();
    let listen_addr = get_bind_addr(&node);
    let shutdown_state = &node.streams.shutdown_state;

    shutdown_state.begin_drain();

    // Give listener time to be ready
    sleep(Duration::from_millis(50)).await;

    let mut client = connect_and_send_frame(
        listen_addr,
        MessageType::Message,
        b"req-header",
        b"req-payload",
    )
    .await;

    let (reply_type, reply_header, reply_payload) = read_one_frame(&mut client).await;
    assert_eq!(reply_type, MessageType::ShuttingDown);
    assert_eq!(&reply_header[..], b"req-header"); // Original header echoed back
    assert_eq!(reply_payload.len(), 0); // Empty payload

    shutdown_state.teardown_token().cancel();
}
// --- Test 19: Drain accepts Response frames ---
// A Response frame sent after begin_drain() must still be delivered on the
// response stream with its header and payload intact.
#[tokio::test]
async fn test_tcp_drain_accepts_responses() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let listen_addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Enter the drain phase before any traffic is sent.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Send a Response frame; the client stream is dropped right after the write.
    connect_and_send_frame(listen_addr, MessageType::Response, b"resp-header", b"resp-payload")
        .await;

    // Wait (bounded) for the frame to show up on the response stream.
    let pending = handle.streams.response_stream.recv_async();
    let (got_header, got_payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&got_header[..], b"resp-header");
    assert_eq!(&got_payload[..], b"resp-payload");

    shutdown.teardown_token().cancel();
}
// --- Test 20: Drain accepts Event frames ---
// An Event frame sent after begin_drain() must still be delivered on the
// event stream with its header and payload intact.
#[tokio::test]
async fn test_tcp_drain_accepts_events() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let listen_addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Enter the drain phase before any traffic is sent.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Send an Event frame; the client stream is dropped right after the write.
    connect_and_send_frame(listen_addr, MessageType::Event, b"evt-header", b"evt-payload").await;

    // Wait (bounded) for the frame to show up on the event stream.
    let pending = handle.streams.event_stream.recv_async();
    let (got_header, got_payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&got_header[..], b"evt-header");
    assert_eq!(&got_payload[..], b"evt-payload");

    shutdown.teardown_token().cancel();
}
// --- Test 21: New connection during drain still accepts responses ---
// A connection opened only after drain has begun must still be able to
// deliver a Response frame to the response stream.
#[tokio::test]
async fn test_tcp_new_connection_during_drain() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let listen_addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Drain starts BEFORE the client connects.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Open a brand-new connection and send a Response frame over it.
    connect_and_send_frame(listen_addr, MessageType::Response, b"new-resp", b"new-payload").await;

    // Wait (bounded) for the frame to show up on the response stream.
    let pending = handle.streams.response_stream.recv_async();
    let (got_header, got_payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&got_header[..], b"new-resp");
    assert_eq!(&got_payload[..], b"new-payload");

    shutdown.teardown_token().cancel();
}
// --- Test 22: ShuttingDown frame roundtrip ---
// Encoding a ShuttingDown frame and decoding it again must yield the same
// message type and header, with an empty payload.
#[test]
fn test_shutting_down_frame_roundtrip() {
    use bytes::BytesMut;
    use tokio_util::codec::Decoder;

    let sent_header = b"correlation-header";
    let sent_payload = b"";

    // Encode the frame into an in-memory buffer.
    let mut wire: Vec<u8> = Vec::new();
    TcpFrameCodec::encode_frame_sync(
        &mut wire,
        MessageType::ShuttingDown,
        sent_header,
        sent_payload,
    )
    .unwrap();

    // Feed the encoded bytes back through the decoder.
    let mut decoder = TcpFrameCodec::new();
    let mut input = BytesMut::from(&wire[..]);
    let decoded = decoder.decode(&mut input).unwrap();
    let (got_type, got_header, got_payload) = decoded.unwrap();

    assert_eq!(got_type, MessageType::ShuttingDown);
    assert_eq!(&got_header[..], sent_header);
    assert!(got_payload.is_empty());
}
// --- Test 23: Full graceful shutdown lifecycle ---
// Walks through all three phases: normal delivery, drain (messages rejected,
// responses accepted) while an in-flight guard is held, then teardown once the
// guard is released.
#[tokio::test]
async fn test_tcp_graceful_shutdown_lifecycle() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let listen_addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;
    let recv_limit = Duration::from_secs(2);

    // Baseline: before drain, a Message frame is delivered to the message stream.
    connect_and_send_frame(listen_addr, MessageType::Message, b"normal-msg", b"normal-pay").await;
    let pending_message = handle.streams.message_stream.recv_async();
    let (message_header, _message_payload) = timeout(recv_limit, pending_message)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&message_header[..], b"normal-msg");

    // Hold an in-flight guard to simulate a request that is still being processed.
    let in_flight = shutdown.acquire();
    assert_eq!(shutdown.in_flight_count(), 1);

    // Phase 1: begin drain.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // New Message frames are now answered with ShuttingDown.
    let mut rejected_client =
        connect_and_send_frame(listen_addr, MessageType::Message, b"reject-me", b"").await;
    let (reply_type, _, _) = read_one_frame(&mut rejected_client).await;
    assert_eq!(reply_type, MessageType::ShuttingDown);

    // Response frames are still delivered.
    connect_and_send_frame(listen_addr, MessageType::Response, b"still-ok", b"data").await;
    let pending_response = handle.streams.response_stream.recv_async();
    let (response_header, _) = timeout(recv_limit, pending_response)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&response_header[..], b"still-ok");

    // Run phases 2 and 3 in the background; the task waits on drain while the guard is held.
    let background_state = handle.streams.shutdown_state.clone();
    let shutdown_task = tokio::spawn(async move {
        // Phase 2: wait for in-flight work to finish.
        background_state.wait_for_drain().await;
        // Phase 3: teardown.
        background_state.teardown_token().cancel();
    });

    // With the guard still held, the background task must not have finished.
    sleep(Duration::from_millis(100)).await;
    assert!(!shutdown_task.is_finished());

    // Releasing the guard lets the drain complete and teardown fire.
    drop(in_flight);
    timeout(recv_limit, shutdown_task)
        .await
        .expect("shutdown should complete")
        .unwrap();

    let teardown = shutdown.teardown_token();
    assert!(teardown.is_cancelled());
}
// --- Test 24: Shutdown timeout forces teardown ---
// If the drain wait times out because a guard is never released, the shutdown
// sequence must still proceed to teardown.
#[tokio::test]
async fn test_tcp_shutdown_timeout_forces_teardown() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let shutdown = &handle.streams.shutdown_state;

    // Keep one in-flight guard alive for the whole test.
    let _guard = shutdown.acquire();

    let background_state = handle.streams.shutdown_state.clone();
    let shutdown_task = tokio::spawn(async move {
        background_state.begin_drain();
        // Phase 2: bounded wait for drain; the outcome is deliberately ignored.
        let drain_wait = background_state.wait_for_drain();
        let _ = timeout(Duration::from_millis(100), drain_wait).await;
        // Phase 3: forced teardown while the guard is still held.
        background_state.teardown_token().cancel();
    });

    timeout(Duration::from_secs(2), shutdown_task)
        .await
        .expect("shutdown should complete via timeout")
        .unwrap();

    // Teardown fired even though the guard was never released.
    let teardown = shutdown.teardown_token();
    assert!(teardown.is_cancelled());

    // The guard is still counted as in flight; that is fine because teardown was forced.
    assert_eq!(shutdown.in_flight_count(), 1);
}
// --- Test 25: Outbound sends during drain ---
// A transport that has begun draining must still be able to send a Response
// to a registered peer.
#[tokio::test]
async fn test_outbound_sends_during_drain() {
    // Two transports, each registered as the other's peer.
    let handle_a = TestTransportHandle::new_tcp().await.unwrap();
    let handle_b = TestTransportHandle::new_tcp().await.unwrap();
    handle_a.register_peer(&handle_b).unwrap();
    handle_b.register_peer(&handle_a).unwrap();

    // Only transport A enters the drain phase.
    handle_a.streams.shutdown_state.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // A sends a Response to B while draining.
    let target = handle_b.instance_id;
    let out_header = b"response-hdr".to_vec();
    let out_payload = b"response-pay".to_vec();
    handle_a.send(target, out_header, out_payload, MessageType::Response);

    // B must receive it on its response stream within the time limit.
    let pending = handle_b.streams.response_stream.recv_async();
    let (got_header, got_payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&got_header[..], b"response-hdr");
    assert_eq!(&got_payload[..], b"response-pay");

    handle_a.streams.shutdown_state.teardown_token().cancel();
    handle_b.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 26: Connection writer exits on teardown ---
// After transport A is shut down, a further send from A must not panic. The
// test makes no assertion about where the message ends up.
#[tokio::test]
async fn test_connection_writer_exits_on_teardown() {
    let handle_a = TestTransportHandle::new_tcp().await.unwrap();
    let handle_b = TestTransportHandle::new_tcp().await.unwrap();
    handle_a.register_peer(&handle_b).unwrap();

    // A first send establishes the connection writer task.
    let target = handle_b.instance_id;
    handle_a.send(target, b"setup".to_vec(), b"data".to_vec(), MessageType::Message);

    // Wait until B has actually received it.
    let pending = handle_b.streams.message_stream.recv_async();
    timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");

    // Shut down transport A and give its writer tasks time to exit.
    handle_a.transport.shutdown();
    sleep(Duration::from_millis(200)).await;

    // A send after shutdown should fail through the error handler, not panic.
    handle_a.send(
        handle_b.instance_id,
        b"should-fail".to_vec(),
        b"data".to_vec(),
        MessageType::Message,
    );

    // Leave time for the asynchronous error path to run.
    sleep(Duration::from_millis(100)).await;

    // The message either reaches the error handler or is silently dropped
    // (connection cleared during shutdown). Reaching this point without a
    // panic is the only thing verified.
    handle_b.streams.shutdown_state.teardown_token().cancel();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment