Unverified Commit 008683d6 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding kvbm-engine (#6773)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent cf79c4fc
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "velo-transports"
version = "0.1.0"
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[features]
default = ["http", "nats", "grpc"]
http = ["dep:axum", "dep:reqwest"]
nats = ["dep:async-nats", "dep:bs58"]
grpc = ["dep:tonic", "dep:prost", "dep:tower", "dep:hyper", "dep:http", "dep:http-body", "dep:http-body-util", "dep:hyper-util", "dep:tokio-stream"]
[dependencies]
velo-common = { workspace = true }
anyhow = { workspace = true }
bytes = { workspace = true }
dashmap = { workspace = true }
# derive_builder = { workspace = true }
parking_lot.workspace = true
# serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
# uuid = { workspace = true, features = ["serde", "v4"] }
# xxhash-rust = { workspace = true }
# base64 = "0.22"
flume = "0.12.0"
futures = "0.3"
# hashbrown = "0.16"
# lru = { version = "0.16", features = ["hashbrown"]}
rmp-serde = "1.1"
# serde_bytes = "0.11"
socket2 = "0.6"
tokio-util = { version = "0.7", features = ["codec"] }
[target.'cfg(target_os = "linux")'.dependencies]
nix = { version = "0.30", features = ["sched"] }
# optional dependencies
axum = { version = "0.8", optional = true }
reqwest = { workspace = true , optional = true }
async-nats = { workspace = true, optional = true }
bs58 = { version = "0.5", optional = true }
prost = { version = "0.13", optional = true }
tonic = { version = "0.13.1", optional = true }
tower = { version = "0.5", optional = true }
hyper = { version = "1.0", optional = true }
http = { version = "1.0", optional = true }
http-body = { version = "1.0", optional = true }
http-body-util = { version = "0.1", optional = true }
hyper-util = { version = "0.1", optional = true, features = ["tokio", "server", "server-auto"] }
tokio-stream = { version = "0.1", optional = true, features = ["sync"] }
[dev-dependencies]
tokio = { workspace = true, features = ["test-util", "macros"] }
tower = "0.5"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[build-dependencies]
tonic-build = "0.13.1"
# velo-transports
Multi-transport active message routing for distributed systems.
## Overview
`velo-transports` abstracts TCP, HTTP, NATS, gRPC, and UCX behind a unified `Transport` trait. It provides:
- **Zero-copy `Bytes`** — inbound frames share the receive buffer via `bytes::Bytes` slicing
- **Fire-and-forget with error callbacks**`send_message` is non-blocking; failures are reported through `TransportErrorHandler`
- **Priority-based peer routing** — each peer is registered with all compatible transports; the highest-priority one becomes the primary
- **3-phase graceful shutdown** — Gate → Drain → Teardown with RAII in-flight tracking
- **Peer discovery via `WorkerAddress`** — MessagePack-encoded map of transport key → endpoint bytes
## Architecture
```mermaid
graph TD
VB[VeloBackend] -->|priority routing| TCP[TcpTransport]
VB -->|priority routing| HTTP[HttpTransport]
VB -->|priority routing| NATS[NatsTransport]
VB -->|priority routing| GRPC[GrpcTransport]
VB -->|priority routing| UCX[UcxTransport]
TCP -->|inbound| DS[DataStreams]
HTTP -->|inbound| DS
NATS -->|inbound| DS
GRPC -->|inbound| DS
UCX -->|inbound| DS
DS -->|message_stream| APP[Application]
DS -->|response_stream| APP
DS -->|event_stream| APP
```
`VeloBackend` is the central orchestrator. It starts each transport, builds a composite `WorkerAddress` advertising all endpoints, and manages the peer registry. Inbound frames arrive via three independent `flume` channels in `DataStreams`.
## Feature Flags
| Feature | Default | Dependencies | Description |
|---------|---------|-------------|-------------|
| `http` | ✓ | axum, reqwest | HTTP transport with Axum server |
| `nats` | ✓ | async-nats, bs58 | NATS pub-sub transport |
| `grpc` | ✓ | tonic, prost, tower | gRPC bidirectional streaming |
| `nixl` | ✗ | nixl-sys | UCX/RDMA transport via NIXL (future work) |
TCP transport is always available (no feature gate).
## Transport Summary
| Transport | Protocol | Framing | Key Properties |
|-----------|----------|---------|----------------|
| **TCP** | Raw TCP | 11-byte preamble + header + payload | Zero-copy codec, DashMap connection pool, CPU pinning, keepalive |
| **HTTP** | HTTP/1.1 POST | Base64 header in `X-Transport-Header`, raw body | Fire-and-forget (202 Accepted), Axum server, drain → 503 |
| **NATS** | NATS pub-sub | Base64 header in NATS HeaderMap, raw payload | Subject scheme `velo.{b58}.{type}`, request/reply health, drain via unsub |
| **gRPC** | HTTP/2 streaming | Protobuf `FramedData` wrapper (preamble + header + payload) | Bidirectional streaming, tonic channels, exponential backoff reconnect |
| **UCX** | RDMA Active Messages | 3 fixed lanes (msg/resp/event) | `LocalSet` thread for `Rc<Worker>`, lazy endpoints, zero-copy RDMA |
## Shutdown Model
Graceful shutdown follows three phases:
1. **Gate**`begin_drain()` flips an atomic flag. Transports reject new inbound `Message` frames (TCP sends `ShuttingDown` response, HTTP returns 503, NATS unsubscribes from message subject).
2. **Drain**`wait_for_drain()` blocks until all `InFlightGuard`s are dropped. Policy is either `WaitForever` or `Timeout(Duration)`.
3. **Teardown** — Cancel the teardown token, stopping all listener loops and writer tasks. Call `shutdown()` on each transport to clean up connections.
Response, Ack, and Event frames continue flowing during drain so in-flight work can complete.
## Wire Format (TCP)
```text
┌──────────────┬───────────┬──────────────┬───────────────┬────────┬─────────┐
│ version (2B) │ type (1B) │ hdr_len (4B) │ pay_len (4B) │ header │ payload │
│ u16 BE │ u8 │ u32 BE │ u32 BE │ bytes │ bytes │
└──────────────┴───────────┴──────────────┴───────────────┴────────┴─────────┘
```
- **version**: Schema version (currently 1)
- **type**: `Message(0)`, `Response(1)`, `Ack(2)`, `Event(3)`, `ShuttingDown(4)`
- **hdr_len / pay_len**: Lengths of the following header and payload sections
- **Max frame size**: 16 MB
The gRPC transport wraps the same preamble + header + payload in a Protobuf `FramedData` message.
# TCP Transport Design
## Overview
The TCP transport provides high-performance, zero-copy message delivery over raw TCP connections. It uses a custom 11-byte frame preamble for minimal overhead and supports CPU pinning for predictable latency.
## Connection Management
### DashMap Connection Pool
Peer connections are managed via two `DashMap` instances:
- `peers: DashMap<InstanceId, SocketAddr>` — registered peer addresses
- `connections: DashMap<InstanceId, ConnectionHandle>` — active writer task handles
Connections are established lazily on first `send_message()` call. Each connection spawns a dedicated writer task that owns the TCP stream.
### Writer Tasks
Each `ConnectionHandle` wraps a bounded `flume::Sender<SendTask>` (default capacity: 256). The send path:
1. **Fast path**: `try_send()` on existing connection — non-blocking, no allocation
2. **Slow path (full)**: `send_async()` via spawned task — applies backpressure
3. **Slow path (new)**: `get_or_create_connection()` — establishes TCP connection and spawns writer
The writer task loop:
```
recv_async(SendTask) → encode_frame(&mut stream, ...) → loop
```
## TcpFrameCodec
### Wire Format
```
[u16 BE: schema_version(1)] [u8: frame_type] [u32 BE: header_len] [u32 BE: payload_len] [header] [payload]
```
Total preamble: 11 bytes. Maximum frame: 16 MB.
### Decoder State Machine
The codec uses a two-state decoder for streaming TCP:
```
AwaitingHeader ──(11 bytes available)──→ AwaitingData ──(data available)──→ emit frame, reset
```
Zero-copy is achieved via `BytesMut::split_to().freeze()` — the output `Bytes` share the underlying receive buffer.
### Encoder
`encode_frame()` writes three segments via `write_all()`:
1. Preamble (11 bytes)
2. Header bytes
3. Payload bytes
`write_vectored()` is intentionally not used because it doesn't guarantee writing all bytes for payloads exceeding the kernel send buffer (~128KB).
## TCP Listener
### Frame Routing
Incoming frames are routed based on `MessageType`:
| MessageType | Target Stream |
|------------|---------------|
| Message | `message_stream` |
| Response | `response_stream` |
| Ack, Event | `event_stream` |
| ShuttingDown | `response_stream` (for correlation) |
### Drain Behavior
During drain (`ShutdownState::is_draining()`):
- **Message** frames are rejected: a `ShuttingDown` frame is sent back with the original header for correlation
- **Response/Ack/Event** frames pass through normally
### CPU Pinning (Linux)
`RuntimeConfig::CpuPin(cpu_id)` creates a single-threaded tokio runtime with the thread pinned to the specified CPU core via `nix::sched::sched_setaffinity`. This provides predictable latency by avoiding context switches.
On non-Linux platforms, `CpuPin` falls back to a regular single-threaded runtime with a warning.
## Socket Configuration
Both listener and writer sides configure:
- `TCP_NODELAY` — disable Nagle's algorithm for low-latency framing
- `SO_LINGER(1s)` — ensure clean socket shutdown
- `TCP_KEEPALIVE` — 60s idle time, 10s probe interval
- **Buffer sizes** — 1 MB send/receive buffers for high throughput
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package velo.streaming.v1;
// Velo streaming service for active messages
//
// This service uses bidirectional streaming to send pre-framed messages.
// Messages are wrapped in FramedData which contains our custom TCP frame format.
service VeloStreaming
{
// Bidirectional streaming RPC
// Client sends framed messages, server receives them
// The response stream is unused (empty) in our implementation
rpc Stream(stream FramedData) returns (stream FramedData);
}
// Wrapper for pre-framed message data
//
// The fields map to TCP frame segments:
// preamble: [u16: version][u8: msg_type][u32: header_len][u32: payload_len]
// header: [header bytes]
// payload: [payload bytes]
message FramedData
{
bytes preamble = 1;
bytes header = 2;
bytes payload = 3;
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Internal address builder for constructing WorkerAddress instances.
//!
//! This module provides the builder pattern for creating WorkerAddress instances
//! from transport-specific endpoint data. It is internal to velo-transports.
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use velo_common::{WorkerAddress, WorkerAddressError};
/// Builder for constructing WorkerAddress instances.
///
/// This provides a mutable interface for collecting transport endpoints
/// before encoding them into the immutable WorkerAddress format.
#[derive(Debug, Clone, Default)]
pub(crate) struct WorkerAddressBuilder {
entries: HashMap<String, Bytes>,
}
impl WorkerAddressBuilder {
/// Create a new empty builder.
pub fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
/// Add a new entry to the map.
///
/// Returns an error if the key already exists.
pub fn add_entry(
&mut self,
key: impl Into<String>,
value: impl Into<Bytes>,
) -> Result<(), WorkerAddressError> {
let key = key.into();
if self.entries.contains_key(&key) {
return Err(WorkerAddressError::KeyExists(key));
}
self.entries.insert(key, value.into());
Ok(())
}
/// Check if a key exists in the map.
#[allow(dead_code)]
pub fn has_entry(&self, key: &str) -> bool {
self.entries.contains_key(key)
}
/// Get a reference to an entry's value.
#[allow(dead_code)]
pub fn get_entry(&self, key: &str) -> Option<&Bytes> {
self.entries.get(key)
}
/// Merge another WorkerAddress into this builder.
///
/// This decodes the other address and attempts to add all its entries to this builder.
/// If any key from the other address already exists in this builder, returns an error
/// and leaves the builder unchanged.
pub fn merge(&mut self, other: &WorkerAddress) -> Result<(), WorkerAddressError> {
let map = decode_to_map(other.as_bytes())?;
// First check if any keys would conflict
for key in map.keys() {
if self.entries.contains_key(key.as_ref()) {
return Err(WorkerAddressError::KeyExists(key.to_string()));
}
}
// All keys are unique, now add them
for (key, value) in map {
self.entries.insert(key.to_string(), value);
}
Ok(())
}
/// Build the WorkerAddress from this builder.
///
/// This encodes the map into MessagePack binary format.
pub fn build(self) -> Result<WorkerAddress, WorkerAddressError> {
// Convert HashMap<String, Bytes> to HashMap<String, Vec<u8>> for MessagePack
let serializable: HashMap<String, Vec<u8>> = self
.entries
.into_iter()
.map(|(k, v)| (k, v.to_vec()))
.collect();
// Encode to MessagePack
let encoded = rmp_serde::to_vec(&serializable)?;
Ok(WorkerAddress::from_encoded(encoded))
}
}
/// Decode WorkerAddress bytes from MessagePack into a map.
fn decode_to_map(bytes: &[u8]) -> Result<HashMap<Arc<str>, Bytes>, WorkerAddressError> {
if bytes.is_empty() {
return Err(WorkerAddressError::InvalidFormat("Empty bytes".to_string()));
}
// Decode MessagePack
let decoded: HashMap<String, Vec<u8>> = rmp_serde::from_slice(bytes)?;
// Convert to HashMap<Arc<str>, Bytes>
Ok(decoded
.into_iter()
.map(|(k, v)| (Arc::from(k.as_str()), Bytes::from(v)))
.collect())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_builder_basic() {
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("endpoint", Bytes::from_static(b"tcp://127.0.0.1:5555"))
.unwrap();
builder
.add_entry("protocol", Bytes::from_static(b"tcp"))
.unwrap();
assert!(builder.has_entry("endpoint"));
assert!(builder.has_entry("protocol"));
assert!(!builder.has_entry("nonexistent"));
let address = builder.build().unwrap();
assert!(!address.as_bytes().is_empty());
// Verify we can read the entries back
let entry = address.get_entry("endpoint").unwrap();
assert_eq!(entry, Some(Bytes::from_static(b"tcp://127.0.0.1:5555")));
}
#[test]
fn test_builder_add_duplicate_key() {
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("key", Bytes::from_static(b"value1"))
.unwrap();
let result = builder.add_entry("key", Bytes::from_static(b"value2"));
assert!(matches!(result, Err(WorkerAddressError::KeyExists(_))));
}
#[test]
fn test_builder_merge() {
// Build first address
let mut builder1 = WorkerAddressBuilder::new();
builder1
.add_entry("tcp", Bytes::from_static(b"tcp://127.0.0.1:5555"))
.unwrap();
let address1 = builder1.build().unwrap();
// Build second address
let mut builder2 = WorkerAddressBuilder::new();
builder2
.add_entry("rdma", Bytes::from_static(b"rdma://10.0.0.1:6666"))
.unwrap();
let address2 = builder2.build().unwrap();
// Merge both into a new builder
let mut builder3 = WorkerAddressBuilder::new();
builder3.merge(&address1).unwrap();
builder3.merge(&address2).unwrap();
let final_address = builder3.build().unwrap();
// Verify both entries are present
assert_eq!(
final_address.get_entry("tcp").unwrap(),
Some(Bytes::from_static(b"tcp://127.0.0.1:5555"))
);
assert_eq!(
final_address.get_entry("rdma").unwrap(),
Some(Bytes::from_static(b"rdma://10.0.0.1:6666"))
);
}
#[test]
fn test_builder_merge_with_conflict() {
let mut builder1 = WorkerAddressBuilder::new();
builder1
.add_entry("tcp", Bytes::from_static(b"tcp://127.0.0.1:5555"))
.unwrap();
let address1 = builder1.build().unwrap();
let mut builder2 = WorkerAddressBuilder::new();
builder2
.add_entry("tcp", Bytes::from_static(b"tcp://different:5555"))
.unwrap();
let address2 = builder2.build().unwrap();
// Merge first address
let mut builder3 = WorkerAddressBuilder::new();
builder3.merge(&address1).unwrap();
// Try to merge conflicting address - should fail
let result = builder3.merge(&address2);
assert!(matches!(result, Err(WorkerAddressError::KeyExists(_))));
// Builder should be unchanged
assert!(builder3.has_entry("tcp"));
assert_eq!(
builder3.get_entry("tcp").unwrap(),
&Bytes::from_static(b"tcp://127.0.0.1:5555")
);
}
#[test]
fn test_empty_builder() {
let builder = WorkerAddressBuilder::new();
let address = builder.build().unwrap();
// Empty address should still be valid
let transports = address.available_transports().unwrap();
assert_eq!(transports.len(), 0);
}
// ========================================================================
// Integration tests: Verify WorkerAddressBuilder (velo-transports) produces
// addresses that WorkerAddress (velo-common) can correctly decode.
// These tests ensure the two crates stay in sync on the wire format.
// ========================================================================
#[test]
fn test_builder_address_integration_get_entry() {
// Build an address with multiple entries
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("tcp", Bytes::from_static(b"tcp://127.0.0.1:5555"))
.unwrap();
builder
.add_entry("rdma", Bytes::from_static(b"rdma://10.0.0.1:6666"))
.unwrap();
builder
.add_entry("binary_data", Bytes::from_static(&[0x00, 0x01, 0x02, 0xFF]))
.unwrap();
let address = builder.build().unwrap();
// Verify WorkerAddress::get_entry() correctly decodes each entry
assert_eq!(
address.get_entry("tcp").unwrap(),
Some(Bytes::from_static(b"tcp://127.0.0.1:5555"))
);
assert_eq!(
address.get_entry("rdma").unwrap(),
Some(Bytes::from_static(b"rdma://10.0.0.1:6666"))
);
assert_eq!(
address.get_entry("binary_data").unwrap(),
Some(Bytes::from_static(&[0x00, 0x01, 0x02, 0xFF]))
);
assert_eq!(address.get_entry("nonexistent").unwrap(), None);
}
#[test]
fn test_builder_address_integration_available_transports() {
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("tcp", Bytes::from_static(b"tcp://127.0.0.1:5555"))
.unwrap();
builder
.add_entry("rdma", Bytes::from_static(b"rdma://10.0.0.1:6666"))
.unwrap();
builder
.add_entry("grpc", Bytes::from_static(b"grpc://localhost:9000"))
.unwrap();
let address = builder.build().unwrap();
// Verify WorkerAddress::available_transports() returns all keys
let transports = address.available_transports().unwrap();
assert_eq!(transports.len(), 3);
assert!(transports.contains(&velo_common::TransportKey::from("tcp")));
assert!(transports.contains(&velo_common::TransportKey::from("rdma")));
assert!(transports.contains(&velo_common::TransportKey::from("grpc")));
}
#[test]
fn test_builder_address_integration_checksum_stability() {
// Build same address twice - checksums should match
let mut builder1 = WorkerAddressBuilder::new();
builder1
.add_entry("key", Bytes::from_static(b"value"))
.unwrap();
let address1 = builder1.build().unwrap();
let mut builder2 = WorkerAddressBuilder::new();
builder2
.add_entry("key", Bytes::from_static(b"value"))
.unwrap();
let address2 = builder2.build().unwrap();
// Same content should produce same checksum
assert_eq!(address1.checksum(), address2.checksum());
// Different content should produce different checksum
let mut builder3 = WorkerAddressBuilder::new();
builder3
.add_entry("key", Bytes::from_static(b"different"))
.unwrap();
let address3 = builder3.build().unwrap();
assert_ne!(address1.checksum(), address3.checksum());
}
#[test]
fn test_builder_address_integration_bytes_roundtrip() {
// Build an address
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("endpoint", Bytes::from_static(b"test://value"))
.unwrap();
let address = builder.build().unwrap();
// Get raw bytes and create new address via from_encoded
let raw_bytes = address.to_bytes();
let address2 = WorkerAddress::from_encoded(raw_bytes);
// Both should be equal and decode the same
assert_eq!(address, address2);
assert_eq!(address.checksum(), address2.checksum());
assert_eq!(
address.get_entry("endpoint").unwrap(),
address2.get_entry("endpoint").unwrap()
);
}
#[test]
fn test_builder_address_integration_serde_roundtrip() {
// Build an address
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("tcp", Bytes::from_static(b"tcp://127.0.0.1:5555"))
.unwrap();
let address = builder.build().unwrap();
// Serialize to JSON and back
let json = serde_json::to_string(&address).unwrap();
let deserialized: WorkerAddress = serde_json::from_str(&json).unwrap();
// Should be equal and decode correctly
assert_eq!(address, deserialized);
assert_eq!(
address.get_entry("tcp").unwrap(),
deserialized.get_entry("tcp").unwrap()
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#![deny(missing_docs)]
//! Multi-transport active message routing framework.
//!
//! `velo-transports` abstracts TCP, HTTP, NATS, gRPC, and UCX behind a unified
//! [`Transport`] trait with zero-copy [`bytes::Bytes`], fire-and-forget error
//! callbacks, priority-based peer routing, and 3-phase graceful shutdown.
//!
//! # Architecture
//!
//! [`VeloBackend`] is the central orchestrator. It holds a set of transports,
//! each identified by a [`TransportKey`]. When a peer registers, the backend
//! selects a *primary* transport (highest-priority compatible transport) and
//! records any alternatives. Outbound messages are routed through the primary
//! transport by default, or through an explicit alternative.
//!
//! Inbound messages arrive via [`DataStreams`] — three independent channels
//! for messages, responses, and events.
//!
//! # Shutdown
//!
//! Graceful shutdown follows three phases:
//! 1. **Gate** — flip the draining flag; transports reject new inbound requests.
//! 2. **Drain** — wait for all in-flight requests to complete.
//! 3. **Teardown** — cancel listeners/writers and call `shutdown()` on each transport.
mod address;
pub mod tcp;
#[cfg(unix)]
pub mod uds;
// #[cfg(feature = "ucx")]
// pub mod ucx;
// #[cfg(feature = "http")]
// pub mod http;
// #[cfg(feature = "nats")]
// pub mod nats;
// #[cfg(feature = "grpc")]
// pub mod grpc;
mod transport;
use std::{collections::HashMap, sync::Arc};
use dashmap::DashMap;
use parking_lot::Mutex;
// Public re-exports from velo-common
pub use velo_common::{
InstanceId, PeerInfo, TransportKey, WorkerAddress, WorkerAddressError, WorkerId,
};
// Internal builder for address construction
use address::WorkerAddressBuilder;
// Re-export transport types
pub use transport::{
DataStreams, HealthCheckError, InFlightGuard, MessageType, ShutdownPolicy, ShutdownState,
Transport, TransportAdapter, TransportError, TransportErrorHandler, make_channels,
};
/// Errors returned by [`VeloBackend`] operations.
#[derive(Debug, thiserror::Error)]
pub enum VeloBackendError {
/// No transport could accept the peer's address.
#[error("No compatible transports found")]
NoCompatibleTransports,
/// The target instance was never registered via [`VeloBackend::register_peer`].
#[error("Transport not found for instance: {0}")]
InstanceNotRegistered(InstanceId),
/// The worker ID is not in the fast-path cache.
#[error("Worker not found: {0}")]
WorkerNotRegistered(WorkerId),
/// The requested [`TransportKey`] does not match any loaded transport.
#[error("Transport not found: {0}")]
TransportNotFound(TransportKey),
/// The priority list does not match the set of available transports.
#[error("Invalid transport priority: {0}")]
InvalidTransportPriority(String),
}
/// Central orchestrator that aggregates multiple transports and routes messages
/// to peers via priority-based transport selection.
///
/// Each peer is registered with all compatible transports; the highest-priority
/// compatible transport becomes the *primary* for that peer. Worker IDs are
/// cached for fast-path routing without discovery lookups.
pub struct VeloBackend {
instance_id: InstanceId,
address: WorkerAddress,
priorities: Mutex<Vec<TransportKey>>,
transports: HashMap<TransportKey, Arc<dyn Transport>>,
primary_transport: DashMap<InstanceId, Arc<dyn Transport>>,
alternative_transports: DashMap<InstanceId, Vec<TransportKey>>,
workers: DashMap<WorkerId, InstanceId>,
shutdown_state: ShutdownState,
#[allow(dead_code)]
runtime: tokio::runtime::Handle,
}
impl VeloBackend {
/// Create a new backend from a list of transports.
///
/// Each transport is started (bound, listening) and its address is merged
/// into a composite [`WorkerAddress`]. Returns the backend and the
/// [`DataStreams`] receivers for inbound messages.
pub async fn new(
backend_transports: Vec<Arc<dyn Transport>>,
) -> anyhow::Result<(Self, DataStreams)> {
let instance_id = InstanceId::new_v4();
// build worker address
let mut priorities = Vec::new();
let mut builder = WorkerAddressBuilder::new();
let mut transports = HashMap::new();
let (adapter, data_streams) = transport::make_channels();
let shutdown_state = adapter.shutdown_state.clone();
let runtime = tokio::runtime::Handle::current();
for transport in backend_transports {
transport
.start(instance_id, adapter.clone(), runtime.clone())
.await?;
builder.merge(&transport.address())?;
priorities.push(transport.key());
transports.insert(transport.key(), transport);
}
let address = builder.build()?;
Ok((
Self {
instance_id,
address,
transports,
priorities: Mutex::new(priorities),
primary_transport: DashMap::new(),
alternative_transports: DashMap::new(),
workers: DashMap::new(),
shutdown_state,
runtime,
},
data_streams,
))
}
/// Returns this backend's unique instance identifier.
pub fn instance_id(&self) -> InstanceId {
self.instance_id
}
/// Returns a [`PeerInfo`] describing this backend (instance ID + composite address).
pub fn peer_info(&self) -> PeerInfo {
PeerInfo::new(self.instance_id, self.address.clone())
}
/// Returns `true` if the given instance has been registered via [`register_peer`](Self::register_peer).
pub fn is_registered(&self, instance_id: InstanceId) -> bool {
self.primary_transport.contains_key(&instance_id)
}
/// Fast-path lookup of worker_id -> instance_id from cache.
///
/// Returns `WorkerNotRegistered` if the worker is not in the cache.
/// Higher layers (Velo, VeloEvents, ActiveMessageClient) should handle
/// discovery fallback when this returns an error.
///
/// # Example
/// ```ignore
/// match backend.try_translate_worker_id(worker_id) {
/// Ok(instance_id) => { /* fast path: send immediately */ }
/// Err(VeloBackendError::WorkerNotRegistered(_)) => {
/// /* slow path: query discovery, then register_peer() */
/// }
/// }
/// ```
pub fn try_translate_worker_id(
&self,
worker_id: WorkerId,
) -> Result<InstanceId, VeloBackendError> {
self.workers
.get(&worker_id)
.map(|entry| *entry)
.ok_or(VeloBackendError::WorkerNotRegistered(worker_id))
}
/// Deprecated: Use `try_translate_worker_id()` for explicit fast-path semantics.
#[deprecated(since = "0.7.0", note = "Use try_translate_worker_id() instead")]
pub fn translate_worker_id(&self, worker_id: WorkerId) -> Result<InstanceId, VeloBackendError> {
self.try_translate_worker_id(worker_id)
}
/// Check if an instance_id is registered.
pub fn has_instance(&self, instance_id: InstanceId) -> bool {
self.primary_transport.contains_key(&instance_id)
}
/// Send a message to a registered peer via its primary transport.
///
/// Returns [`VeloBackendError::InstanceNotRegistered`] if the peer has not
/// been registered with [`register_peer`](Self::register_peer).
pub fn send_message(
&self,
target: InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
) -> anyhow::Result<()> {
let transport = self
.primary_transport
.get(&target)
.ok_or(VeloBackendError::InstanceNotRegistered(target))?;
transport.send_message(target, header, payload, message_type, on_error);
Ok(())
}
/// Send a message to a registered peer via a specific transport.
///
/// If `transport_key` matches the peer's primary transport, the message is
/// sent directly. Otherwise, the alternative transports are searched.
/// Returns [`VeloBackendError::NoCompatibleTransports`] if the requested
/// transport is not available for this peer.
pub fn send_message_with_transport(
&self,
target: InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
transport_key: TransportKey,
) -> anyhow::Result<()> {
let transport = self
.primary_transport
.get(&target)
.ok_or(VeloBackendError::InstanceNotRegistered(target))?;
if transport.value().key() == transport_key {
transport.send_message(target, header, payload, message_type, on_error);
return Ok(());
} else {
// if we got here, we can unwrap because there is an entry in the alternative_transports map
let alternative_transports = self
.alternative_transports
.get(&target)
.ok_or(VeloBackendError::InstanceNotRegistered(target))?;
for alternative_transport in alternative_transports.iter() {
if *alternative_transport == transport_key
&& let Some(transport) = self.transports.get(alternative_transport)
{
transport.send_message(target, header, payload, message_type, on_error);
return Ok(());
}
}
}
Err(VeloBackendError::NoCompatibleTransports)?
}
/// Send message to a worker (fast-path only).
///
/// This method uses `try_translate_worker_id()` for fast-path lookup.
/// Returns `WorkerNotRegistered` error if the worker is not in the cache.
///
/// For automatic discovery, use the two-phase pattern:
/// ```ignore
/// match backend.send_message_to_worker(...) {
/// Ok(()) => { /* success */ }
/// Err(e) if matches_worker_not_registered(&e) => {
/// tokio::spawn(async move {
/// let instance_id = backend.resolve_and_register_worker(worker_id).await?;
/// backend.send_message(instance_id, ...)?;
/// });
/// }
/// }
/// ```
pub fn send_message_to_worker(
&self,
worker_id: WorkerId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
) -> anyhow::Result<()> {
let instance_id = self.try_translate_worker_id(worker_id)?;
self.send_message(instance_id, header, payload, message_type, on_error)
}
/// Register a remote peer with all compatible transports.
///
/// The highest-priority compatible transport becomes the peer's *primary*.
/// Returns [`VeloBackendError::NoCompatibleTransports`] if no transport
/// can accept the peer's address.
pub fn register_peer(&self, peer: PeerInfo) -> Result<(), VeloBackendError> {
// try to register the peer with each transport
// we must have at least one compatible transport; otherwise, return an error
let instance_id = peer.instance_id();
let mut compatible_transports = Vec::new();
for (key, transport) in self.transports.iter() {
if transport.register(peer.clone()).is_ok() {
compatible_transports.push(key.clone());
}
}
if compatible_transports.is_empty() {
return Err(VeloBackendError::NoCompatibleTransports);
}
// sort against the preferred transports
let sorted_transports = self
.priorities
.lock()
.iter()
.filter(|key| compatible_transports.contains(key))
.cloned()
.collect::<Vec<TransportKey>>();
assert!(
!sorted_transports.is_empty(),
"failed to properly sort compatible transports"
);
let primary_transport_key = sorted_transports[0].clone();
let alternative_transport_keys = sorted_transports[1..].to_vec();
let primary_transport = self.transports.get(&primary_transport_key).unwrap();
self.primary_transport
.insert(instance_id, primary_transport.clone());
self.alternative_transports
.insert(instance_id, alternative_transport_keys);
self.workers.insert(instance_id.worker_id(), instance_id);
Ok(())
}
/// Get the available transports.
pub fn available_transports(&self) -> Vec<TransportKey> {
self.transports.keys().cloned().collect()
}
/// Set the priority of the transports.
///
/// The list of [`TransportKey`]s must be an order set of the available transports.
pub fn set_transport_priority(
&self,
priorities: Vec<TransportKey>,
) -> Result<(), VeloBackendError> {
let required_transports = self.available_transports();
if required_transports.len() != priorities.len() {
return Err(VeloBackendError::InvalidTransportPriority(format!(
"Required transports: {:?}, provided priorities: {:?}",
required_transports, priorities
)));
}
for priority in &priorities {
if !required_transports.contains(priority) {
return Err(VeloBackendError::InvalidTransportPriority(format!(
"Priority transport not found: {:?}",
priority
)));
}
}
let mut guard = self.priorities.lock();
*guard = priorities;
Ok(())
}
/// Get the shared shutdown state.
pub fn shutdown_state(&self) -> &ShutdownState {
&self.shutdown_state
}
/// Perform a graceful 3-phase shutdown.
///
/// 1. **Gate**: Flip the draining flag and notify each transport via `begin_drain()`.
/// 2. **Drain**: Wait for all in-flight requests to complete (per `policy`).
/// 3. **Teardown**: Cancel the teardown token and call `shutdown()` on each transport.
pub async fn graceful_shutdown(&self, policy: ShutdownPolicy) {
// Phase 1: Gate
self.shutdown_state.begin_drain();
for transport in self.transports.values() {
transport.begin_drain();
}
// Phase 2: Drain
match policy {
ShutdownPolicy::WaitForever => {
self.shutdown_state.wait_for_drain().await;
}
ShutdownPolicy::Timeout(duration) => {
let _ = tokio::time::timeout(duration, self.shutdown_state.wait_for_drain()).await;
}
}
// Phase 3: Teardown
self.shutdown_state.teardown_token().cancel();
for transport in self.transports.values() {
transport.shutdown();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use futures::future::BoxFuture;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
/// Mock transport for testing VeloBackend logic without real networking.
struct MockTransport {
key: TransportKey,
address: WorkerAddress,
accept_register: bool,
started: AtomicBool,
drained: AtomicBool,
shut_down: AtomicBool,
send_count: AtomicUsize,
}
impl MockTransport {
fn new(key: &str, accept_register: bool) -> Arc<Self> {
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry(key, format!("mock://{}", key).into_bytes())
.unwrap();
let address = builder.build().unwrap();
Arc::new(Self {
key: TransportKey::from(key),
address,
accept_register,
started: AtomicBool::new(false),
drained: AtomicBool::new(false),
shut_down: AtomicBool::new(false),
send_count: AtomicUsize::new(0),
})
}
}
impl Transport for MockTransport {
fn key(&self) -> TransportKey {
self.key.clone()
}
fn address(&self) -> WorkerAddress {
self.address.clone()
}
fn register(&self, _peer_info: PeerInfo) -> Result<(), TransportError> {
if self.accept_register {
Ok(())
} else {
Err(TransportError::NoEndpoint)
}
}
fn send_message(
&self,
_instance_id: InstanceId,
_header: Vec<u8>,
_payload: Vec<u8>,
_message_type: MessageType,
_on_error: Arc<dyn TransportErrorHandler>,
) {
self.send_count.fetch_add(1, Ordering::Relaxed);
}
fn start(
&self,
_instance_id: InstanceId,
_channels: TransportAdapter,
_rt: tokio::runtime::Handle,
) -> BoxFuture<'_, anyhow::Result<()>> {
self.started.store(true, Ordering::Relaxed);
Box::pin(async { Ok(()) })
}
fn shutdown(&self) {
self.shut_down.store(true, Ordering::Relaxed);
}
fn begin_drain(&self) {
self.drained.store(true, Ordering::Relaxed);
}
fn check_health(
&self,
_instance_id: InstanceId,
_timeout: Duration,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<(), transport::HealthCheckError>>
+ Send
+ '_,
>,
> {
Box::pin(async { Ok(()) })
}
}
struct NoopErrorHandler;
impl TransportErrorHandler for NoopErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, _error: String) {}
}
/// Helper: build a PeerInfo with entries for specified transport keys.
fn make_peer_info(keys: &[&str]) -> PeerInfo {
let instance_id = InstanceId::new_v4();
let mut builder = WorkerAddressBuilder::new();
for key in keys {
builder
.add_entry(*key, format!("mock://{}", key).into_bytes())
.unwrap();
}
let address = builder.build().unwrap();
PeerInfo::new(instance_id, address)
}
#[tokio::test]
async fn test_new_single_transport() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t.clone() as Arc<dyn Transport>])
.await
.unwrap();
assert!(t.started.load(Ordering::Relaxed));
// instance_id should be a valid v4 UUID (non-zero)
assert!(!backend.instance_id().as_bytes().iter().all(|&b| b == 0));
assert_eq!(backend.available_transports().len(), 1);
}
#[tokio::test]
async fn test_new_multiple_transports() {
let t1 = MockTransport::new("tcp", true);
let t2 = MockTransport::new("http", true);
let (backend, _streams) = VeloBackend::new(vec![
t1.clone() as Arc<dyn Transport>,
t2.clone() as Arc<dyn Transport>,
])
.await
.unwrap();
assert!(t1.started.load(Ordering::Relaxed));
assert!(t2.started.load(Ordering::Relaxed));
assert_eq!(backend.available_transports().len(), 2);
}
#[tokio::test]
async fn test_register_peer_selects_primary_by_priority() {
let t1 = MockTransport::new("tcp", true);
let t2 = MockTransport::new("http", true);
let (backend, _streams) = VeloBackend::new(vec![
t1.clone() as Arc<dyn Transport>,
t2.clone() as Arc<dyn Transport>,
])
.await
.unwrap();
let peer = make_peer_info(&["tcp", "http"]);
let peer_id = peer.instance_id();
backend.register_peer(peer).unwrap();
assert!(backend.is_registered(peer_id));
// Primary should be "tcp" (first in priority)
let primary = backend.primary_transport.get(&peer_id).unwrap();
assert_eq!(primary.value().key(), TransportKey::from("tcp"));
}
#[tokio::test]
async fn test_register_peer_no_compatible_transports() {
// Transport rejects all registrations
let t = MockTransport::new("tcp", false);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let peer = make_peer_info(&["tcp"]);
let result = backend.register_peer(peer);
assert!(matches!(
result,
Err(VeloBackendError::NoCompatibleTransports)
));
}
#[tokio::test]
async fn test_register_peer_stores_worker_mapping() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let peer = make_peer_info(&["tcp"]);
let peer_id = peer.instance_id();
let worker_id = peer_id.worker_id();
backend.register_peer(peer).unwrap();
let resolved = backend.try_translate_worker_id(worker_id).unwrap();
assert_eq!(resolved, peer_id);
}
#[tokio::test]
async fn test_send_message_routes_to_primary() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t.clone() as Arc<dyn Transport>])
.await
.unwrap();
let peer = make_peer_info(&["tcp"]);
let peer_id = peer.instance_id();
backend.register_peer(peer).unwrap();
backend
.send_message(
peer_id,
vec![1],
vec![2],
MessageType::Message,
Arc::new(NoopErrorHandler),
)
.unwrap();
assert_eq!(t.send_count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_send_message_unregistered_peer() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let result = backend.send_message(
InstanceId::new_v4(),
vec![],
vec![],
MessageType::Message,
Arc::new(NoopErrorHandler),
);
assert!(result.is_err());
}
#[tokio::test]
async fn test_send_message_with_transport_primary_match() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t.clone() as Arc<dyn Transport>])
.await
.unwrap();
let peer = make_peer_info(&["tcp"]);
let peer_id = peer.instance_id();
backend.register_peer(peer).unwrap();
backend
.send_message_with_transport(
peer_id,
vec![1],
vec![2],
MessageType::Message,
Arc::new(NoopErrorHandler),
TransportKey::from("tcp"),
)
.unwrap();
assert_eq!(t.send_count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_send_message_with_transport_alternative() {
let t1 = MockTransport::new("tcp", true);
let t2 = MockTransport::new("http", true);
let (backend, _streams) = VeloBackend::new(vec![
t1.clone() as Arc<dyn Transport>,
t2.clone() as Arc<dyn Transport>,
])
.await
.unwrap();
let peer = make_peer_info(&["tcp", "http"]);
let peer_id = peer.instance_id();
backend.register_peer(peer).unwrap();
// Send via "http" (the alternative transport)
backend
.send_message_with_transport(
peer_id,
vec![1],
vec![2],
MessageType::Message,
Arc::new(NoopErrorHandler),
TransportKey::from("http"),
)
.unwrap();
assert_eq!(t2.send_count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_send_message_with_transport_not_found() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let peer = make_peer_info(&["tcp"]);
let peer_id = peer.instance_id();
backend.register_peer(peer).unwrap();
let result = backend.send_message_with_transport(
peer_id,
vec![],
vec![],
MessageType::Message,
Arc::new(NoopErrorHandler),
TransportKey::from("grpc"),
);
assert!(result.is_err());
}
#[tokio::test]
async fn test_try_translate_worker_id_not_found() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let result = backend.try_translate_worker_id(InstanceId::new_v4().worker_id());
assert!(matches!(
result,
Err(VeloBackendError::WorkerNotRegistered(_))
));
}
#[tokio::test]
async fn test_set_transport_priority_valid() {
let t1 = MockTransport::new("tcp", true);
let t2 = MockTransport::new("http", true);
let (backend, _streams) =
VeloBackend::new(vec![t1 as Arc<dyn Transport>, t2 as Arc<dyn Transport>])
.await
.unwrap();
// Reverse the priority
backend
.set_transport_priority(vec![TransportKey::from("http"), TransportKey::from("tcp")])
.unwrap();
}
#[tokio::test]
async fn test_set_transport_priority_wrong_length() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let result = backend
.set_transport_priority(vec![TransportKey::from("tcp"), TransportKey::from("http")]);
assert!(matches!(
result,
Err(VeloBackendError::InvalidTransportPriority(_))
));
}
#[tokio::test]
async fn test_set_transport_priority_unknown_key() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let result = backend.set_transport_priority(vec![TransportKey::from("unknown")]);
assert!(matches!(
result,
Err(VeloBackendError::InvalidTransportPriority(_))
));
}
#[tokio::test]
async fn test_graceful_shutdown_calls_all_transports() {
let t1 = MockTransport::new("tcp", true);
let t2 = MockTransport::new("http", true);
let (backend, _streams) = VeloBackend::new(vec![
t1.clone() as Arc<dyn Transport>,
t2.clone() as Arc<dyn Transport>,
])
.await
.unwrap();
backend
.graceful_shutdown(ShutdownPolicy::Timeout(Duration::from_millis(100)))
.await;
assert!(t1.drained.load(Ordering::Relaxed));
assert!(t2.drained.load(Ordering::Relaxed));
assert!(t1.shut_down.load(Ordering::Relaxed));
assert!(t2.shut_down.load(Ordering::Relaxed));
assert!(backend.shutdown_state().is_draining());
assert!(backend.shutdown_state().teardown_token().is_cancelled());
}
#[tokio::test]
async fn test_peer_info_roundtrip() {
let t = MockTransport::new("tcp", true);
let (backend, _streams) = VeloBackend::new(vec![t as Arc<dyn Transport>])
.await
.unwrap();
let info = backend.peer_info();
assert_eq!(info.instance_id(), backend.instance_id());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Zero-copy TCP framing codec for ActiveMessage transport
//!
//! Wire format (7-15 bytes overhead):
//! ```text
//! [u16 BE: schema_version][u8: frame_type][u32 BE: header_len][u32 BE: payload_len][header bytes][payload bytes]
//! ```
//!
//! The codec uses `BytesMut` for receiving and `Bytes` for output, enabling
//! zero-copy buffer slicing where header and payload share the underlying buffer.
use bytes::{Buf, Bytes, BytesMut};
use std::io;
use std::io::Write;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio_util::codec::Decoder;
use crate::MessageType;
/// Current schema version
const SCHEMA_VERSION_V1: u16 = 1;
/// Maximum frame size (16 MB)
const MAX_FRAME_SIZE: u32 = 16 * 1024 * 1024;
/// Minimum frame header size (version + type + 2 lengths)
const MIN_HEADER_SIZE: usize = 2 + 1 + 4 + 4; // 11 bytes
/// Zero-copy frame decoder for TCP transport
///
/// This decoder maintains state across multiple calls to support partial
/// frame reception. It decodes frames into (MessageType, header: Bytes, payload: Bytes)
/// where header and payload are zero-copy slices of the receive buffer.
#[derive(Debug, Clone)]
pub struct TcpFrameCodec {
state: DecodeState,
}
#[derive(Debug, Clone, Copy)]
enum DecodeState {
/// Waiting for frame header (version + type + lengths)
AwaitingHeader,
/// Waiting for frame data (header + payload), with known lengths
AwaitingData {
frame_type: MessageType,
header_len: u32,
payload_len: u32,
},
}
impl TcpFrameCodec {
/// Create a new frame codec
pub fn new() -> Self {
Self {
state: DecodeState::AwaitingHeader,
}
}
/// Build the frame preamble (metadata header)
///
/// Returns a fixed-size preamble containing version, message type, and lengths.
#[inline]
pub fn build_preamble(
msg_type: MessageType,
header_len: u32,
payload_len: u32,
) -> io::Result<[u8; MIN_HEADER_SIZE]> {
// Validate lengths before building preamble
Self::validate_lengths(header_len, payload_len)?;
let mut preamble = [0u8; MIN_HEADER_SIZE];
// Layout:
// [0..2) = version
// [2] = msg_type
// [3..7) = header_len
// [7..11)= payload_len (total 11 bytes)
preamble[0..2].copy_from_slice(&SCHEMA_VERSION_V1.to_be_bytes());
preamble[2] = msg_type.as_u8();
preamble[3..7].copy_from_slice(&header_len.to_be_bytes());
preamble[7..11].copy_from_slice(&payload_len.to_be_bytes());
Ok(preamble)
}
/// Parse message type from a preamble
///
/// Validates the schema version and extracts the message type from the preamble.
#[inline]
pub fn parse_message_type_from_preamble(preamble: &[u8]) -> io::Result<MessageType> {
if preamble.len() < MIN_HEADER_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Preamble too short",
));
}
// Validate schema version
let schema_version = u16::from_be_bytes([preamble[0], preamble[1]]);
if schema_version != SCHEMA_VERSION_V1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Unsupported schema version: {} (expected {})",
schema_version, SCHEMA_VERSION_V1
),
));
}
// Extract and validate message type
MessageType::from_u8(preamble[2]).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid message type: {}", preamble[2]),
)
})
}
/// Encode and write a frame asynchronously
///
/// Uses `write_all()` for each segment to handle partial writes correctly.
/// TCP `write_vectored()` doesn't guarantee writing all bytes in one call —
/// for payloads exceeding the kernel send buffer (~128KB), it returns a short
/// write count. Using `write_all()` per segment ensures correctness for all sizes.
#[inline]
pub async fn encode_frame<W: AsyncWrite + Unpin>(
writer: &mut W,
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> tokio::io::Result<()> {
let preamble = Self::build_preamble(msg_type, header.len() as u32, payload.len() as u32)?;
writer.write_all(&preamble).await?;
writer.write_all(header).await?;
writer.write_all(payload).await?;
Ok(())
}
/// Encode and write a frame synchronously
///
/// Uses `write_all()` for each segment to handle partial writes correctly.
#[inline]
pub fn encode_frame_sync<W: Write>(
writer: &mut W,
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> std::io::Result<()> {
let preamble = Self::build_preamble(msg_type, header.len() as u32, payload.len() as u32)?;
writer.write_all(&preamble)?;
writer.write_all(header)?;
writer.write_all(payload)?;
Ok(())
}
/// Validate that lengths are reasonable
fn validate_lengths(header_len: u32, payload_len: u32) -> io::Result<()> {
let total_len = header_len
.checked_add(payload_len)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Frame size overflow"))?;
if total_len > MAX_FRAME_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Frame size {} exceeds maximum {}",
total_len, MAX_FRAME_SIZE
),
));
}
Ok(())
}
}
impl Default for TcpFrameCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for TcpFrameCodec {
type Item = (MessageType, Bytes, Bytes);
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
match self.state {
DecodeState::AwaitingHeader => {
// Need at least MIN_HEADER_SIZE bytes
if src.len() < MIN_HEADER_SIZE {
return Ok(None);
}
// Parse header without consuming bytes yet
let schema_version = u16::from_be_bytes([src[0], src[1]]);
let frame_type_byte = src[2];
let header_len = u32::from_be_bytes([src[3], src[4], src[5], src[6]]);
let payload_len = u32::from_be_bytes([src[7], src[8], src[9], src[10]]);
// Validate schema version
if schema_version != SCHEMA_VERSION_V1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Unsupported schema version: {} (expected {})",
schema_version, SCHEMA_VERSION_V1
),
));
}
// Parse frame type
let frame_type = MessageType::from_u8(frame_type_byte).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid frame type: {}", frame_type_byte),
)
})?;
// Validate lengths before allocating/waiting
Self::validate_lengths(header_len, payload_len)?;
// Advance buffer past header
src.advance(MIN_HEADER_SIZE);
// Transition to data state
self.state = DecodeState::AwaitingData {
frame_type,
header_len,
payload_len,
};
}
DecodeState::AwaitingData {
frame_type,
header_len,
payload_len,
..
} => {
let total_data_len = (header_len + payload_len) as usize;
// Wait for full data
if src.len() < total_data_len {
return Ok(None);
}
// Zero-copy: split buffer into header and payload slices
let header = src.split_to(header_len as usize).freeze();
let payload = src.split_to(payload_len as usize).freeze();
// Reset state for next frame
self.state = DecodeState::AwaitingHeader;
return Ok(Some((frame_type, header, payload)));
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Test helper to encode a frame into a Vec<u8> for verification (async)
async fn encode_frame_to_bytes(
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> io::Result<Vec<u8>> {
let mut buf = Vec::new();
TcpFrameCodec::encode_frame(&mut buf, msg_type, header, payload).await?;
Ok(buf)
}
/// Test helper to encode a frame into a Vec<u8> for verification (sync)
fn encode_frame_to_bytes_sync(
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> io::Result<Vec<u8>> {
let mut buf = Vec::new();
TcpFrameCodec::encode_frame_sync(&mut buf, msg_type, header, payload)?;
Ok(buf)
}
/// Helper to create raw frames with arbitrary parameters for negative testing.
///
/// This function bypasses normal validation and encoding logic to create
/// intentionally invalid frames (wrong schema version, oversized frames, etc.)
/// for testing error handling paths. Use `encode_frame_to_bytes()` for
/// testing valid frame construction.
fn create_unsafe_frame(
schema_version: u16,
frame_type: MessageType,
header: &[u8],
payload: &[u8],
) -> BytesMut {
let mut buf = BytesMut::new();
buf.extend_from_slice(&schema_version.to_be_bytes());
buf.extend_from_slice(&[frame_type.as_u8()]);
buf.extend_from_slice(&(header.len() as u32).to_be_bytes());
buf.extend_from_slice(&(payload.len() as u32).to_be_bytes());
buf.extend_from_slice(header);
buf.extend_from_slice(payload);
buf
}
#[test]
fn test_decode_message_frame() {
let mut codec = TcpFrameCodec::new();
let header = b"test-header";
let payload = b"test-payload-data";
let framed = encode_frame_to_bytes_sync(MessageType::Message, header, payload).unwrap();
let mut buf = BytesMut::from(&framed[..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (msg_type, decoded_header, decoded_payload) = result.unwrap();
assert_eq!(msg_type, MessageType::Message);
assert_eq!(decoded_header, Bytes::from(header.as_ref()));
assert_eq!(decoded_payload, Bytes::from(payload.as_ref()));
}
#[test]
fn test_decode_all_frame_types() {
let frame_types = [
MessageType::Message,
MessageType::Response,
MessageType::Ack,
MessageType::Event,
];
for frame_type in &frame_types {
let mut codec = TcpFrameCodec::new();
let header = b"header";
let payload = b"payload";
let framed = encode_frame_to_bytes_sync(*frame_type, header, payload).unwrap();
let mut buf = BytesMut::from(&framed[..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (decoded_type, _, _) = result.unwrap();
assert_eq!(decoded_type, *frame_type);
}
}
#[test]
fn test_decode_empty_payload() {
let mut codec = TcpFrameCodec::new();
let header = b"ack-header";
let payload = b"";
let framed = encode_frame_to_bytes_sync(MessageType::Ack, header, payload).unwrap();
let mut buf = BytesMut::from(&framed[..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (msg_type, decoded_header, decoded_payload) = result.unwrap();
assert_eq!(msg_type, MessageType::Ack);
assert_eq!(&decoded_header[..], header);
assert_eq!(decoded_payload.len(), 0);
}
#[test]
fn test_decode_partial_frame() {
let mut codec = TcpFrameCodec::new();
let header = b"test-header";
let payload = b"test-payload";
let full_frame = encode_frame_to_bytes_sync(MessageType::Message, header, payload).unwrap();
// Send only first 5 bytes (partial header)
let mut buf = BytesMut::from(&full_frame[..5]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_none()); // Not enough data
// Send rest of header
buf.extend_from_slice(&full_frame[5..MIN_HEADER_SIZE]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_none()); // Header parsed, but data not yet available
// Send complete data
buf.extend_from_slice(&full_frame[MIN_HEADER_SIZE..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (msg_type, decoded_header, decoded_payload) = result.unwrap();
assert_eq!(msg_type, MessageType::Message);
assert_eq!(&decoded_header[..], header);
assert_eq!(&decoded_payload[..], payload);
}
#[test]
fn test_decode_invalid_schema_version() {
let mut codec = TcpFrameCodec::new();
let header = b"header";
let payload = b"payload";
let mut buf = create_unsafe_frame(999, MessageType::Message, header, payload);
let result = codec.decode(&mut buf);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Unsupported schema version")
);
}
#[test]
fn test_decode_invalid_frame_type() {
let mut codec = TcpFrameCodec::new();
let mut buf = BytesMut::new();
// Create frame with invalid type byte (255)
buf.extend_from_slice(&SCHEMA_VERSION_V1.to_be_bytes());
buf.extend_from_slice(&[255u8]); // Invalid frame type
buf.extend_from_slice(&10u32.to_be_bytes()); // header len
buf.extend_from_slice(&10u32.to_be_bytes()); // payload len
let result = codec.decode(&mut buf);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Invalid frame type")
);
}
#[test]
fn test_decode_frame_too_large() {
let mut codec = TcpFrameCodec::new();
let mut buf = BytesMut::new();
// Create frame that exceeds MAX_FRAME_SIZE
buf.extend_from_slice(&SCHEMA_VERSION_V1.to_be_bytes());
buf.extend_from_slice(&[MessageType::Message.as_u8()]);
buf.extend_from_slice(&(MAX_FRAME_SIZE / 2 + 1).to_be_bytes());
buf.extend_from_slice(&(MAX_FRAME_SIZE / 2 + 1).to_be_bytes());
let result = codec.decode(&mut buf);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
}
#[test]
fn test_decode_multiple_frames() {
let mut codec = TcpFrameCodec::new();
let mut buf = BytesMut::new();
// Add two frames to buffer
let frame1 =
encode_frame_to_bytes_sync(MessageType::Message, b"header1", b"payload1").unwrap();
let frame2 =
encode_frame_to_bytes_sync(MessageType::Response, b"header2", b"payload2").unwrap();
buf.extend_from_slice(&frame1);
buf.extend_from_slice(&frame2);
// Decode first frame
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (msg_type, header, payload) = result.unwrap();
assert_eq!(msg_type, MessageType::Message);
assert_eq!(&header[..], b"header1");
assert_eq!(&payload[..], b"payload1");
// Decode second frame
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (msg_type, header, payload) = result.unwrap();
assert_eq!(msg_type, MessageType::Response);
assert_eq!(&header[..], b"header2");
assert_eq!(&payload[..], b"payload2");
// No more frames
assert!(buf.is_empty());
}
#[test]
fn test_zero_copy_bytes_share_buffer() {
let mut codec = TcpFrameCodec::new();
let header = b"shared-header";
let payload = b"shared-payload";
let framed = encode_frame_to_bytes_sync(MessageType::Message, header, payload).unwrap();
let mut buf = BytesMut::from(&framed[..]);
let result = codec.decode(&mut buf).unwrap().unwrap();
let (_, decoded_header, decoded_payload) = result;
// Verify the slices contain correct data
assert_eq!(&decoded_header[..], header);
assert_eq!(&decoded_payload[..], payload);
// Clone should be cheap (just RC increment)
let header_clone = decoded_header.clone();
let payload_clone = decoded_payload.clone();
assert_eq!(decoded_header, header_clone);
assert_eq!(decoded_payload, payload_clone);
}
#[test]
fn test_encode_frame() {
let header = b"test-header";
let payload = b"test-payload";
let framed = encode_frame_to_bytes_sync(MessageType::Message, header, payload).unwrap();
// Verify frame structure
assert_eq!(framed.len(), MIN_HEADER_SIZE + header.len() + payload.len());
// Verify header fields
assert_eq!(
u16::from_be_bytes([framed[0], framed[1]]),
SCHEMA_VERSION_V1
);
assert_eq!(framed[2], MessageType::Message.as_u8());
assert_eq!(
u32::from_be_bytes([framed[3], framed[4], framed[5], framed[6]]),
header.len() as u32
);
assert_eq!(
u32::from_be_bytes([framed[7], framed[8], framed[9], framed[10]]),
payload.len() as u32
);
// Verify data
assert_eq!(
&framed[MIN_HEADER_SIZE..MIN_HEADER_SIZE + header.len()],
header
);
assert_eq!(&framed[MIN_HEADER_SIZE + header.len()..], payload);
}
#[test]
fn test_encode_all_message_types() {
let header = b"header";
let payload = b"payload";
for msg_type in &[
MessageType::Message,
MessageType::Response,
MessageType::Ack,
MessageType::Event,
] {
let framed = encode_frame_to_bytes_sync(*msg_type, header, payload).unwrap();
assert_eq!(framed[2], msg_type.as_u8());
}
}
#[test]
fn test_encode_empty_payload() {
let header = b"ack-header";
let payload = b"";
let framed = encode_frame_to_bytes_sync(MessageType::Ack, header, payload).unwrap();
assert_eq!(framed.len(), MIN_HEADER_SIZE + header.len());
assert_eq!(
u32::from_be_bytes([framed[7], framed[8], framed[9], framed[10]]),
0
);
}
#[test]
fn test_encode_frame_too_large() {
let header = vec![0u8; (MAX_FRAME_SIZE / 2 + 1) as usize];
let payload = vec![0u8; (MAX_FRAME_SIZE / 2 + 1) as usize];
let result = encode_frame_to_bytes_sync(MessageType::Message, &header, &payload);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
}
#[test]
fn test_round_trip_encode_decode() {
let mut codec = TcpFrameCodec::new();
let header = b"round-trip-header";
let payload = b"round-trip-payload-data";
// Encode
let framed = encode_frame_to_bytes_sync(MessageType::Response, header, payload).unwrap();
// Decode
let mut buf = BytesMut::from(&framed[..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_some());
let (msg_type, decoded_header, decoded_payload) = result.unwrap();
assert_eq!(msg_type, MessageType::Response);
assert_eq!(&decoded_header[..], header);
assert_eq!(&decoded_payload[..], payload);
}
#[test]
fn test_round_trip_all_types() {
let types = [
MessageType::Message,
MessageType::Response,
MessageType::Ack,
MessageType::Event,
];
for msg_type in &types {
let mut codec = TcpFrameCodec::new();
let header = b"header";
let payload = b"payload";
let framed = encode_frame_to_bytes_sync(*msg_type, header, payload).unwrap();
let mut buf = BytesMut::from(&framed[..]);
let result = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(result.0, *msg_type);
assert_eq!(&result.1[..], header);
assert_eq!(&result.2[..], payload);
}
}
#[test]
fn test_encode_frame_sync() {
let header = b"sync-header";
let payload = b"sync-payload";
let framed = encode_frame_to_bytes_sync(MessageType::Message, header, payload).unwrap();
// Verify frame structure
assert_eq!(framed.len(), MIN_HEADER_SIZE + header.len() + payload.len());
// Verify preamble fields
assert_eq!(
u16::from_be_bytes([framed[0], framed[1]]),
SCHEMA_VERSION_V1
);
assert_eq!(framed[2], MessageType::Message.as_u8());
assert_eq!(
u32::from_be_bytes([framed[3], framed[4], framed[5], framed[6]]),
header.len() as u32
);
assert_eq!(
u32::from_be_bytes([framed[7], framed[8], framed[9], framed[10]]),
payload.len() as u32
);
// Verify data
assert_eq!(
&framed[MIN_HEADER_SIZE..MIN_HEADER_SIZE + header.len()],
header
);
assert_eq!(&framed[MIN_HEADER_SIZE + header.len()..], payload);
}
#[test]
fn test_sync_async_produce_same_output() {
let header = b"test-header";
let payload = b"test-payload";
// Encode with sync version
let sync_framed =
encode_frame_to_bytes_sync(MessageType::Response, header, payload).unwrap();
// Encode with async version (using tokio runtime)
let async_framed = tokio::runtime::Runtime::new()
.unwrap()
.block_on(encode_frame_to_bytes(
MessageType::Response,
header,
payload,
))
.unwrap();
// Both should produce identical output
assert_eq!(sync_framed, async_framed);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP listener for ActiveMessage transport
//!
//! This module provides a TCP server that accepts incoming connections,
//! decodes framed messages using zero-copy techniques, and routes them
//! to the appropriate transport streams.
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::StreamExt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener as TokioTcpListener, TcpStream};
use tokio::runtime::{Handle, Runtime};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
use super::framing::TcpFrameCodec;
/// Runtime configuration for the TCP listener
pub enum RuntimeConfig {
/// Use an existing tokio runtime handle.
Handle(Handle),
/// Use a provided tokio runtime.
Runtime(Arc<Runtime>),
/// Create a single-threaded runtime pinned to the specified CPU core (Linux only).
CpuPin(usize),
}
/// High-performance TCP listener for ActiveMessage transport
///
/// This listener accepts incoming TCP connections and routes decoded frames
/// to the appropriate transport streams with zero-copy performance.
pub struct TcpListener {
bind_addr: SocketAddr,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
runtime_config: RuntimeConfig,
listener: Option<std::net::TcpListener>,
}
impl TcpListener {
/// Create a new builder for TcpListener
pub fn builder() -> TcpListenerBuilder {
TcpListenerBuilder::new()
}
/// Start the listener and serve incoming connections
///
/// This method blocks or spawns based on the runtime configuration:
/// - For Handle/Runtime: spawns tasks and returns immediately
/// - For CpuPin: creates a pinned runtime and blocks until cancellation
pub async fn serve(mut self) -> Result<()> {
// Extract runtime config to avoid borrow issues
let runtime_config = std::mem::replace(
&mut self.runtime_config,
RuntimeConfig::Handle(Handle::current()),
);
match runtime_config {
RuntimeConfig::Handle(handle) => {
handle.spawn(async move {
if let Err(e) = self.run_server().await {
error!("TCP listener error: {}", e);
}
});
Ok(())
}
RuntimeConfig::Runtime(rt) => {
rt.spawn(async move {
if let Err(e) = self.run_server().await {
error!("TCP listener error: {}", e);
}
});
Ok(())
}
RuntimeConfig::CpuPin(cpu_id) => {
let rt = Self::create_pinned_runtime(cpu_id)
.context("Failed to create CPU-pinned runtime")?;
rt.block_on(self.run_server())
}
}
}
/// Create a single-threaded runtime pinned to a specific CPU core
#[cfg(target_os = "linux")]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
use nix::sched::{CpuSet, sched_setaffinity};
use nix::unistd::Pid;
tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name("tcp-listener-pinned")
.on_thread_start(move || {
let mut cpu_set = CpuSet::new();
if cpu_set.set(cpu_id).is_ok() {
if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set) {
error!("Failed to pin thread to CPU {}: {}", cpu_id, e);
} else {
debug!("Successfully pinned TCP listener to CPU {}", cpu_id);
}
}
})
.build()
.context("Failed to build tokio runtime")
}
/// Create a single-threaded runtime without CPU pinning (non-Linux platforms)
#[cfg(not(target_os = "linux"))]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
warn!(
"CPU pinning requested (CPU {}) but not supported on this platform",
cpu_id
);
tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name("tcp-listener")
.build()
.context("Failed to build tokio runtime")
}
/// Main server loop that accepts connections
async fn run_server(self) -> Result<()> {
// Use pre-bound listener if provided, otherwise bind to the address
let listener = if let Some(std_listener) = self.listener {
// Set non-blocking for tokio conversion
std_listener
.set_nonblocking(true)
.context("Failed to set listener to non-blocking")?;
TokioTcpListener::from_std(std_listener)
.context("Failed to convert std TcpListener to tokio TcpListener")?
} else {
TokioTcpListener::bind(self.bind_addr)
.await
.context(format!("Failed to bind TCP listener to {}", self.bind_addr))?
};
let local_addr = listener
.local_addr()
.context("Failed to get local address")?;
info!("TCP listener bound to {}", local_addr);
let teardown_token = self.shutdown_state.teardown_token().clone();
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, peer_addr)) => {
debug!("Accepted TCP connection from {}", peer_addr);
let adapter = self.adapter.clone();
let error_handler = self.error_handler.clone();
let shutdown_state = self.shutdown_state.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(
stream,
peer_addr,
adapter,
error_handler,
shutdown_state,
)
.await
{
warn!("Error handling connection from {}: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("Failed to accept TCP connection: {}", e);
}
}
}
_ = teardown_token.cancelled() => {
info!("TCP listener shutting down (teardown)");
break;
}
}
}
Ok(())
}
/// Handle a single TCP connection
async fn handle_connection(
stream: TcpStream,
peer_addr: SocketAddr,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
) -> Result<()> {
debug!("Configuring connection from {}", peer_addr);
// Configure socket for high performance
if let Err(e) = stream.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY on {}: {}", peer_addr, e);
}
#[allow(deprecated)] // Intentional: linger ensures clean socket shutdown
if let Err(e) = stream.set_linger(Some(Duration::from_secs(1))) {
warn!("Failed to set linger on {}: {}", peer_addr, e);
}
// Set keep-alive to detect dead connections
let keepalive = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10));
let sock_ref = socket2::SockRef::from(&stream);
if let Err(e) = sock_ref.set_tcp_keepalive(&keepalive) {
warn!("Failed to set TCP keepalive on {}: {}", peer_addr, e);
}
// Set large receive buffer for high throughput
if let Err(e) = sock_ref.set_recv_buffer_size(1_048_576) {
warn!("Failed to set receive buffer size on {}: {}", peer_addr, e);
}
// Create framed stream with zero-copy codec
let mut framed = Framed::new(stream, TcpFrameCodec::new());
let teardown_token = shutdown_state.teardown_token().clone();
debug!("Connection from {} ready for frames", peer_addr);
loop {
tokio::select! {
frame_result = framed.next() => {
match frame_result {
Some(Ok((msg_type, header, payload))) => {
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if shutdown_state.is_draining() && msg_type == MessageType::Message {
debug!(
"Rejecting Message frame from {} during drain (sending ShuttingDown)",
peer_addr
);
// Echo original header back for correlation, empty payload
if let Err(e) = TcpFrameCodec::encode_frame(
framed.get_mut(),
MessageType::ShuttingDown,
&header,
&[],
)
.await
{
warn!(
"Failed to send ShuttingDown frame to {}: {}",
peer_addr, e
);
}
continue;
}
// Route frame to appropriate stream based on type
if let Err(e) = Self::route_frame(
msg_type,
header,
payload,
&adapter,
&error_handler,
)
.await
{
warn!(
"Failed to route {:?} frame from {}: {}",
msg_type, peer_addr, e
);
}
}
Some(Err(e)) => {
error!("Frame decode error from {}: {}", peer_addr, e);
break;
}
None => {
// Connection closed gracefully (FIN received)
debug!("Connection from {} closed gracefully", peer_addr);
break;
}
}
}
_ = teardown_token.cancelled() => {
debug!("Connection handler for {} torn down", peer_addr);
break;
}
}
}
Ok(())
}
/// Route a decoded frame to the appropriate stream
///
/// This function performs zero-copy routing by transferring ownership of
/// the Bytes to the flume channel. On error, it invokes the error callback
/// with the original data (requiring a clone).
async fn route_frame(
msg_type: MessageType,
header: Bytes,
payload: Bytes,
adapter: &TransportAdapter,
error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
let sender = match msg_type {
MessageType::Message => &adapter.message_stream,
MessageType::Response => &adapter.response_stream,
MessageType::Ack | MessageType::Event => &adapter.event_stream,
MessageType::ShuttingDown => {
// ShuttingDown is an outbound-only frame type; receiving it here
// means a remote peer rejected our request. Route to the response
// stream so higher layers can handle the rejection via correlation.
&adapter.response_stream
}
};
// Try to send with ownership transfer (zero-copy)
match sender.send_async((header, payload)).await {
Ok(_) => Ok(()),
Err(e) => {
// Send failed - invoke error callback with the data
error_handler.on_error(
e.0.0, // header
e.0.1, // payload
format!("Failed to route {:?}", msg_type),
);
Err(anyhow::anyhow!("Failed to send to stream"))
}
}
}
}
/// Builder for TcpListener
pub struct TcpListenerBuilder {
bind_addr: Option<SocketAddr>,
adapter: Option<TransportAdapter>,
error_handler: Option<Arc<dyn TransportErrorHandler>>,
shutdown_state: Option<ShutdownState>,
runtime_config: Option<RuntimeConfig>,
listener: Option<std::net::TcpListener>,
}
impl TcpListenerBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
bind_addr: None,
adapter: None,
error_handler: None,
shutdown_state: None,
runtime_config: None,
listener: None,
}
}
/// Set the bind address
pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
self.bind_addr = Some(addr);
self
}
/// Set the transport adapter
pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
self.adapter = Some(adapter);
self
}
/// Set the error handler
pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
self.error_handler = Some(handler);
self
}
/// Set the shutdown state for graceful drain coordination
pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
self.shutdown_state = Some(state);
self
}
/// Use an existing tokio runtime handle
pub fn with_handle(mut self, handle: Handle) -> Self {
self.runtime_config = Some(RuntimeConfig::Handle(handle));
self
}
/// Use a provided tokio runtime
pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
self.runtime_config = Some(RuntimeConfig::Runtime(runtime));
self
}
/// Create a single-threaded runtime pinned to a specific CPU core
pub fn with_cpu_pin(mut self, cpu_id: usize) -> Self {
self.runtime_config = Some(RuntimeConfig::CpuPin(cpu_id));
self
}
/// Use a pre-bound TcpListener
///
/// This is useful for tests where you want to bind to port 0 and avoid port races.
/// When provided, the bind_addr should still be set (for logging/debugging purposes).
pub fn listener(mut self, listener: Option<std::net::TcpListener>) -> Self {
self.listener = listener;
self
}
/// Build the TcpListener
pub fn build(self) -> Result<TcpListener> {
let bind_addr = self
.bind_addr
.ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;
let adapter = self
.adapter
.ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
let error_handler = self
.error_handler
.ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
let shutdown_state = self.shutdown_state.unwrap_or_default();
let runtime_config = self
.runtime_config
.unwrap_or_else(|| RuntimeConfig::Handle(Handle::current()));
Ok(TcpListener {
bind_addr,
adapter,
error_handler,
shutdown_state,
runtime_config,
listener: self.listener,
})
}
}
impl Default for TcpListenerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::make_channels;
use std::net::{IpAddr, Ipv4Addr};
struct TestErrorHandler;
impl TransportErrorHandler for TestErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
eprintln!("Test error handler: {}", error);
}
}
#[test]
fn test_builder_requires_fields() {
let result = TcpListener::builder().build();
assert!(result.is_err());
}
#[tokio::test]
async fn test_builder_with_all_fields() {
let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = TcpListener::builder()
.bind_addr(bind_addr)
.adapter(adapter)
.error_handler(error_handler)
.build();
assert!(result.is_ok());
}
#[test]
fn test_builder_with_cpu_pin() {
let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = TcpListener::builder()
.bind_addr(bind_addr)
.adapter(adapter)
.error_handler(error_handler)
.with_cpu_pin(0)
.build();
assert!(result.is_ok());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TCP Transport Module
//!
//! This module provides a high-performance TCP transport implementation with:
//! - Zero-copy frame codec for minimal overhead
//! - CPU pinning support for predictable latency
//! - Frame type routing (Message, Response, Ack, Event)
//! - Graceful shutdown with proper FIN handling
//! - Keep-alive for dead connection detection
mod framing;
mod listener;
mod transport;
pub use framing::TcpFrameCodec;
pub use listener::{RuntimeConfig, TcpListener, TcpListenerBuilder};
pub use transport::{TcpTransport, TcpTransportBuilder};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP transport with single-threaded optimizations
//!
//! This implementation uses Rc+RefCell+LocalSet for maximum performance on a single CPU core.
//! All operations run on the same thread as the TCP listener for optimal cache locality.
use anyhow::{Context, Result};
use bytes::Bytes;
use dashmap::DashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use crate::transport::{HealthCheckError, ShutdownState, TransportError, TransportErrorHandler};
use crate::{MessageType, PeerInfo, Transport, TransportAdapter, TransportKey, WorkerAddress};
use super::framing::TcpFrameCodec;
use super::listener::TcpListener;
/// High-performance TCP transport with lock-free concurrent access
///
/// This transport uses `DashMap` for lock-free concurrent access to connection state.
/// Tasks are spawned using `tokio::spawn` for compatibility with the `Transport` trait.
/// For single-threaded performance, run the entire transport in a `LocalSet` context.
pub struct TcpTransport {
// Identity (immutable, no wrapper needed)
key: TransportKey,
bind_addr: SocketAddr,
local_address: WorkerAddress,
// Shared mutable state with DashMap (lock-free)
peers: Arc<DashMap<crate::InstanceId, SocketAddr>>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
// Runtime handle for spawning tasks
runtime: OnceLock<tokio::runtime::Handle>,
// Shutdown coordination
cancel_token: CancellationToken,
shutdown_state: OnceLock<ShutdownState>,
// Send channel capacity for backpressure
channel_capacity: usize,
// Optional pre-bound listener (used for tests to avoid port races)
listener: Mutex<Option<std::net::TcpListener>>,
}
/// Handle to a connection's writer task
#[derive(Clone)]
struct ConnectionHandle {
tx: flume::Sender<SendTask>,
}
/// Task sent to writer task containing pre-encoded frame
struct SendTask {
msg_type: MessageType,
header: Bytes,
payload: Bytes,
on_error: Arc<dyn TransportErrorHandler>,
}
impl SendTask {
fn on_error(self, error: impl Into<String>) {
self.on_error
.on_error(self.header, self.payload, error.into());
}
}
impl TcpTransport {
/// Create a new TCP transport bound to `bind_addr` with the given transport key.
///
/// An optional pre-bound `listener` can be provided (useful for tests binding
/// to port 0). `channel_capacity` controls backpressure on per-connection
/// writer channels (default 256).
pub fn new(
bind_addr: SocketAddr,
key: TransportKey,
local_address: WorkerAddress,
channel_capacity: usize,
listener: Option<std::net::TcpListener>,
) -> Self {
Self {
key,
bind_addr,
local_address,
peers: Arc::new(DashMap::new()),
connections: Arc::new(DashMap::new()),
runtime: OnceLock::new(),
cancel_token: CancellationToken::new(),
shutdown_state: OnceLock::new(),
channel_capacity,
listener: Mutex::new(listener),
}
}
/// Optional: Pre-establish connection after registration
///
/// This can be called after `register()` to eagerly establish the TCP connection
/// instead of waiting for the first `send_message()` call.
pub fn ensure_connected(&self, instance_id: crate::InstanceId) -> Result<()> {
self.get_or_create_connection(instance_id)?;
Ok(())
}
/// Get or create a connection to a peer (lazy initialization)
fn get_or_create_connection(&self, instance_id: crate::InstanceId) -> Result<ConnectionHandle> {
// Fast path: connection already exists and is alive
if let Some(handle) = self.connections.get(&instance_id) {
if !handle.tx.is_disconnected() {
return Ok(handle.clone());
}
// Stale — drop guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
}
let rt = self.runtime.get().ok_or(TransportError::NotStarted)?;
// Atomic check-and-insert via entry API
let handle = match self.connections.entry(instance_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
if !entry.get().tx.is_disconnected() {
entry.get().clone()
} else {
// Stale entry — replace in-place with a fresh connection
let handle = self.create_connection(instance_id, rt)?;
entry.insert(handle.clone());
handle
}
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let handle = self.create_connection(instance_id, rt)?;
entry.insert(handle.clone());
handle
}
};
Ok(handle)
}
/// Create a new connection handle and spawn the writer task.
fn create_connection(
&self,
instance_id: crate::InstanceId,
rt: &tokio::runtime::Handle,
) -> Result<ConnectionHandle> {
let addr = *self
.peers
.get(&instance_id)
.ok_or(TransportError::PeerNotRegistered(instance_id))?
.value();
let (tx, rx) = flume::bounded(self.channel_capacity);
let handle = ConnectionHandle { tx };
let cancel = self.cancel_token.clone();
let conns = Arc::clone(&self.connections);
rt.spawn(connection_writer_task(addr, instance_id, rx, conns, cancel));
debug!("Created new connection to {} ({})", instance_id, addr);
Ok(handle)
}
}
impl Transport for TcpTransport {
fn key(&self) -> TransportKey {
self.key.clone()
}
fn address(&self) -> WorkerAddress {
self.local_address.clone()
}
fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError> {
// Get endpoint from peer's address
let endpoint = peer_info
.worker_address()
.get_entry(&self.key)
.map_err(|_| TransportError::NoEndpoint)?
.ok_or(TransportError::NoEndpoint)?;
// Parse TCP endpoint (expected format: "tcp://host:port" or "host:port")
let addr = parse_tcp_endpoint(&endpoint).map_err(|e| {
error!("Failed to parse TCP endpoint: {}", e);
TransportError::InvalidEndpoint
})?;
// Store peer address
self.peers.insert(peer_info.instance_id(), addr);
debug!("Registered peer {} at {}", peer_info.instance_id(), addr);
Ok(())
}
#[inline]
fn send_message(
&self,
instance_id: crate::InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: std::sync::Arc<dyn TransportErrorHandler>,
) {
// Convert to Bytes (one allocation each)
let header = Bytes::from(header);
let payload = Bytes::from(payload);
let send_msg = SendTask {
msg_type: message_type,
header,
payload,
on_error,
};
// Fast path: try to send on existing connection
let send_msg = match self.connections.get(&instance_id) {
Some(handle) => match handle.tx.try_send(send_msg) {
Ok(()) => return,
Err(flume::TrySendError::Full(send_msg)) => send_msg,
Err(flume::TrySendError::Disconnected(send_msg)) => {
// Drop the guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
// Fall through to slow path to create a fresh connection
send_msg
}
},
None => send_msg,
};
// Slow path: create new connection
let rt = match self.runtime.get() {
Some(rt) => rt,
None => {
send_msg.on_error("Transport not started");
return;
}
};
let handle = match self.get_or_create_connection(instance_id) {
Ok(h) => h,
Err(e) => {
send_msg.on_error(format!("Failed to create connection: {}", e));
return;
}
};
rt.spawn(async move {
if let Err(flume::SendError(send_msg)) = handle.tx.send_async(send_msg).await {
send_msg.on_error("Connection closed");
}
});
}
fn start(
&self,
_instance_id: crate::InstanceId,
channels: TransportAdapter,
rt: tokio::runtime::Handle,
) -> futures::future::BoxFuture<'_, anyhow::Result<()>> {
// Store runtime handle for use in send_message
self.runtime.set(rt.clone()).ok();
// Capture shutdown state from the adapter
self.shutdown_state
.set(channels.shutdown_state.clone())
.ok();
let bind_addr = self.bind_addr;
let shutdown_state = channels.shutdown_state.clone();
// Take ownership of the listener (if present) - we can only start once
let listener = self
.listener
.lock()
.expect("Listener mutex poisoned")
.take();
Box::pin(async move {
// Create error handler that routes to the transport error handler
struct DefaultErrorHandler;
impl TransportErrorHandler for DefaultErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
warn!("Transport error: {}", error);
}
}
// Start TCP listener
let tcp_listener = TcpListener::builder()
.bind_addr(bind_addr)
.adapter(channels)
.error_handler(std::sync::Arc::new(DefaultErrorHandler))
.shutdown_state(shutdown_state)
.listener(listener)
.build()?;
rt.spawn(async move {
if let Err(e) = tcp_listener.serve().await {
error!("TCP listener error: {}", e);
}
});
info!("TCP transport started on {}", bind_addr);
Ok(())
})
}
fn begin_drain(&self) {
// Per-frame gate in the listener handles drain — no-op here.
}
fn shutdown(&self) {
info!("Shutting down TCP transport");
// Cancel the teardown token (Phase 3) to stop the listener and connection handlers
if let Some(state) = self.shutdown_state.get() {
state.teardown_token().cancel();
}
self.cancel_token.cancel();
// Clear connections
self.connections.clear();
}
fn check_health(
&self,
instance_id: crate::InstanceId,
timeout: Duration,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
> {
Box::pin(async move {
// Check if we have an existing connection
let connection_exists = self.connections.contains_key(&instance_id);
if let Some(handle) = self.connections.get(&instance_id) {
// Check if the channel is still connected (socket is still live)
// If the writer task has exited (socket closed), the channel will be disconnected
if !handle.tx.is_disconnected() {
return Ok(()); // Connection is alive and healthy
}
// Channel is disconnected — drop guard and remove stale entry
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
}
// No existing connection or connection is dead - verify peer is reachable
let addr = *self
.peers
.get(&instance_id)
.ok_or(HealthCheckError::PeerNotRegistered)?
.value();
// Try to connect (and immediately drop) to verify peer is reachable
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(_stream)) => {
// Connection successful, drop immediately
// If we never had a connection before, report NeverConnected
// If we had one before that failed, report Ok (peer is reachable now)
if connection_exists {
Ok(())
} else {
Err(HealthCheckError::NeverConnected)
}
}
Ok(Err(_)) => Err(HealthCheckError::ConnectionFailed),
Err(_) => Err(HealthCheckError::Timeout),
}
})
}
}
/// Connection writer task
///
/// This task runs on the LocalSet and handles writing framed bytes to the TCP stream.
/// It receives pre-encoded frames via a flume channel and writes them to the socket.
///
/// Cleanup (draining queued messages and removing the stale map entry) always runs,
/// even if the initial TCP connect fails.
async fn connection_writer_task(
addr: SocketAddr,
instance_id: crate::InstanceId,
rx: flume::Receiver<SendTask>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
cancel_token: CancellationToken,
) -> Result<()> {
let result = connection_writer_inner(addr, instance_id, &rx, &cancel_token).await;
// Always drain queued messages and notify their error handlers.
//
// TODO: There is a tiny race between the drain finishing and `drop(rx)`:
// a sender on another thread could `try_send` successfully in that window,
// and the message would be silently dropped when rx is destroyed. Closing
// this fully would require swapping the map entry with a "poisoned" handle
// (a disconnected tx) before draining, so fast-path senders see a failure
// instead. Not worth the complexity today — at most one message is affected,
// and async senders already get `SendError` once rx is dropped.
while let Ok(msg) = rx.try_recv() {
msg.on_error("Connection closed");
}
// Drop the receiver so our sender half becomes disconnected, then remove
// the stale entry. The predicate ensures we only remove our own entry —
// a replacement connection's tx will still be connected.
drop(rx);
connections.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
debug!("Connection to {} ({}) closed", instance_id, addr);
result
}
/// Inner loop: connect, configure the socket, and send frames until the channel
/// closes or a write error occurs.
async fn connection_writer_inner(
addr: SocketAddr,
instance_id: crate::InstanceId,
rx: &flume::Receiver<SendTask>,
cancel_token: &CancellationToken,
) -> Result<()> {
debug!("Connecting to {}", addr);
let mut stream = tokio::select! {
_ = cancel_token.cancelled() => return Ok(()),
res = TcpStream::connect(addr) => res.context("connect failed")?,
};
if let Err(e) = stream.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY: {}", e);
}
let sock = socket2::SockRef::from(&stream);
if let Err(e) = sock.set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10)),
) {
warn!("Failed to set keepalive: {}", e);
}
if let Err(e) = sock.set_send_buffer_size(1_048_576) {
warn!("Failed to set send buffer size: {}", e);
}
debug!("Connected to {}", addr);
loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => break,
res = rx.recv_async() => match res {
Ok(msg) => msg,
Err(_) => break,
},
};
if let Err(e) =
TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await
{
error!("Write error to {} ({}): {}", instance_id, addr, e);
msg.on_error(format!("Failed to write to stream: {}", e));
break;
}
}
Ok(())
}
/// Parse a TCP endpoint string into a SocketAddr
///
/// Accepts formats:
/// - "tcp://host:port"
/// - "host:port"
fn parse_tcp_endpoint(endpoint: &[u8]) -> Result<SocketAddr> {
let endpoint_str = std::str::from_utf8(endpoint).context("endpoint is not valid UTF-8")?;
// Strip "tcp://" prefix if present
let addr_str = endpoint_str.strip_prefix("tcp://").unwrap_or(endpoint_str);
// Parse as socket address
let mut addrs = addr_str
.to_socket_addrs()
.context("failed to parse socket address")?;
addrs
.next()
.ok_or_else(|| anyhow::anyhow!("no addresses resolved"))
}
/// Resolve a wildcard bind address to a routable address for advertisement.
///
/// When binding to 0.0.0.0 (IPv4 unspecified) or :: (IPv6 unspecified),
/// we need to advertise a routable address that peers can actually connect to.
///
/// For 0.0.0.0, we use 127.0.0.1 (localhost) which works for same-machine communication.
/// For ::, we use ::1 (IPv6 localhost).
///
/// In a production multi-node deployment, this should be replaced with actual
/// network interface discovery or explicit configuration.
fn resolve_advertise_address(bind_addr: SocketAddr) -> SocketAddr {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
match bind_addr.ip() {
IpAddr::V4(ip) if ip.is_unspecified() => {
// 0.0.0.0 -> 127.0.0.1 for local testing
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), bind_addr.port())
}
IpAddr::V6(ip) if ip.is_unspecified() => {
// :: -> ::1 for local testing
SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), bind_addr.port())
}
_ => {
// Already a specific address, use as-is
bind_addr
}
}
}
/// Builder for TcpTransport
pub struct TcpTransportBuilder {
bind_addr: Option<SocketAddr>,
key: Option<TransportKey>,
channel_capacity: usize,
listener: Option<std::net::TcpListener>,
}
impl TcpTransportBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
bind_addr: None,
key: None,
channel_capacity: 256,
listener: None,
}
}
/// Set the bind address
pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
self.bind_addr = Some(addr);
self
}
/// Set the transport key
pub fn key(mut self, key: TransportKey) -> Self {
self.key = Some(key);
self
}
/// Set the channel capacity for backpressure (default: 256)
pub fn channel_capacity(mut self, capacity: usize) -> Self {
self.channel_capacity = capacity;
self
}
/// Use a pre-bound TcpListener instead of binding to a specific address
///
/// This is useful for tests where you want to bind to port 0 and get an OS-assigned
/// port without creating a race condition between binding and starting the transport.
///
/// Note: This is mutually exclusive with `bind_addr()`. Using both will result in an error.
pub fn from_listener(mut self, listener: std::net::TcpListener) -> Result<Self> {
// Validate mutual exclusivity: can't use both bind_addr() and from_listener()
if self.bind_addr.is_some() {
anyhow::bail!(
"Cannot use both bind_addr() and from_listener() - they are mutually exclusive"
);
}
let addr = listener
.local_addr()
.context("Failed to get local address from listener")?;
self.bind_addr = Some(addr);
self.listener = Some(listener);
Ok(self)
}
/// Build the TcpTransport
pub fn build(self) -> Result<TcpTransport> {
let bind_addr = self
.bind_addr
.ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;
let key = self.key.unwrap_or_else(|| TransportKey::from("tcp"));
// Resolve advertise address (handle 0.0.0.0 -> 127.0.0.1 for local testing)
let advertise_addr = resolve_advertise_address(bind_addr);
let local_endpoint = format!("tcp://{}", advertise_addr);
let mut addr_builder = crate::address::WorkerAddressBuilder::new();
addr_builder.add_entry(key.clone(), local_endpoint.as_bytes().to_vec())?;
let local_address = addr_builder.build()?;
Ok(TcpTransport::new(
bind_addr,
key,
local_address,
self.channel_capacity,
self.listener,
))
}
}
impl Default for TcpTransportBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::address::WorkerAddressBuilder;
use std::sync::atomic::{AtomicUsize, Ordering};
use velo_common::PeerInfo;
/// Error handler that discards errors (for tests that don't need to track them).
struct NullErrorHandler;
impl TransportErrorHandler for NullErrorHandler {
fn on_error(&self, _: Bytes, _: Bytes, _: String) {}
}
/// Error handler that counts errors (for tests that verify error routing).
struct TrackingErrorHandler {
count: AtomicUsize,
}
impl TrackingErrorHandler {
fn new() -> Self {
Self {
count: AtomicUsize::new(0),
}
}
fn error_count(&self) -> usize {
self.count.load(Ordering::SeqCst)
}
}
impl TransportErrorHandler for TrackingErrorHandler {
fn on_error(&self, _: Bytes, _: Bytes, _: String) {
self.count.fetch_add(1, Ordering::SeqCst);
}
}
/// Build a `PeerInfo` whose TCP endpoint points at `addr`.
fn make_tcp_peer(addr: SocketAddr) -> PeerInfo {
let instance_id = crate::InstanceId::new_v4();
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("tcp", format!("tcp://{}", addr).into_bytes())
.unwrap();
PeerInfo::new(instance_id, builder.build().unwrap())
}
/// Build a `TcpTransport` with its runtime set, bound to a real listener.
/// Returns `(transport, listener_addr)`.
fn make_transport() -> (TcpTransport, SocketAddr) {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let transport = TcpTransportBuilder::new()
.from_listener(listener)
.unwrap()
.build()
.unwrap();
// Set the runtime handle so `get_or_create_connection` can spawn tasks.
transport
.runtime
.set(tokio::runtime::Handle::current())
.ok();
(transport, addr)
}
/// Insert a stale `ConnectionHandle` into the transport's connections map.
/// A "stale" handle is one whose receiver has been dropped.
fn insert_stale_handle(transport: &TcpTransport, instance_id: crate::InstanceId) {
let (tx, _rx) = flume::bounded::<SendTask>(1);
// Drop _rx immediately so tx.is_disconnected() == true
transport
.connections
.insert(instance_id, ConnectionHandle { tx });
}
#[test]
fn test_parse_tcp_endpoint() {
// With tcp:// prefix
let addr = parse_tcp_endpoint(b"tcp://127.0.0.1:5555").unwrap();
assert_eq!(addr.port(), 5555);
// Without prefix
let addr = parse_tcp_endpoint(b"127.0.0.1:6666").unwrap();
assert_eq!(addr.port(), 6666);
// Invalid
assert!(parse_tcp_endpoint(b"invalid").is_err());
}
#[test]
fn test_builder_requires_bind_addr() {
let result = TcpTransportBuilder::new().build();
assert!(result.is_err());
}
#[test]
fn test_builder_with_bind_addr() {
let addr = "127.0.0.1:0".parse().unwrap();
let result = TcpTransportBuilder::new().bind_addr(addr).build();
assert!(result.is_ok());
}
#[test]
fn test_builder_with_listener() {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let result = TcpTransportBuilder::new().from_listener(listener);
assert!(result.is_ok());
let result = result.unwrap().build();
assert!(result.is_ok());
}
#[test]
fn test_builder_bind_addr_and_listener_mutually_exclusive() {
let addr = "127.0.0.1:0".parse().unwrap();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let result = TcpTransportBuilder::new()
.bind_addr(addr)
.from_listener(listener);
assert!(result.is_err());
let err_msg = format!("{}", result.err().unwrap());
assert!(err_msg.contains("mutually exclusive"));
}
#[test]
fn test_resolve_advertise_address_ipv4_unspecified() {
use std::net::{IpAddr, Ipv4Addr};
// 0.0.0.0 should resolve to 127.0.0.1
let bind_addr: SocketAddr = "0.0.0.0:12345".parse().unwrap();
let resolved = resolve_advertise_address(bind_addr);
assert_eq!(resolved.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
assert_eq!(resolved.port(), 12345);
// Already specific address should remain unchanged
let specific: SocketAddr = "192.168.1.100:8080".parse().unwrap();
let resolved = resolve_advertise_address(specific);
assert_eq!(resolved, specific);
}
#[test]
fn test_resolve_advertise_address_ipv6_unspecified() {
use std::net::{IpAddr, Ipv6Addr};
// :: should resolve to ::1
let bind_addr: SocketAddr = "[::]:12345".parse().unwrap();
let resolved = resolve_advertise_address(bind_addr);
assert_eq!(resolved.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
assert_eq!(resolved.port(), 12345);
// Already specific IPv6 address should remain unchanged
let specific: SocketAddr = "[::1]:8080".parse().unwrap();
let resolved = resolve_advertise_address(specific);
assert_eq!(resolved, specific);
}
#[tokio::test]
async fn test_get_or_create_connection_replaces_stale_handle() {
let (transport, _our_addr) = make_transport();
// Start a listener that the transport can connect to
let peer_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let peer_addr = peer_listener.local_addr().unwrap();
let peer = make_tcp_peer(peer_addr);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert a stale handle
insert_stale_handle(&transport, iid);
assert!(
transport
.connections
.get(&iid)
.unwrap()
.tx
.is_disconnected()
);
// get_or_create_connection should replace the stale handle with a live one
let handle = transport.get_or_create_connection(iid).unwrap();
assert!(!handle.tx.is_disconnected());
// The map entry should also be live
let entry = transport.connections.get(&iid).unwrap();
assert!(!entry.tx.is_disconnected());
}
#[tokio::test]
async fn test_check_health_removes_stale_entry() {
let (transport, _our_addr) = make_transport();
// Start a listener so the peer is "reachable"
let peer_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let peer_addr = peer_listener.local_addr().unwrap();
let peer = make_tcp_peer(peer_addr);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert stale handle — simulates a dead writer task
insert_stale_handle(&transport, iid);
assert!(transport.connections.contains_key(&iid));
// check_health should remove the stale entry and verify the peer is reachable
let result = transport.check_health(iid, Duration::from_secs(2)).await;
// Stale entry should be gone
assert!(!transport.connections.contains_key(&iid));
// Since there WAS a previous connection entry, check_health returns Ok
// (the peer is reachable via our test listener)
assert!(result.is_ok());
}
#[tokio::test]
async fn test_writer_task_cleans_up_on_write_error() {
// Bind a listener, accept once, then drop everything to cause a write error
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let iid = crate::InstanceId::new_v4();
let (tx, rx) = flume::bounded::<SendTask>(8);
let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
Arc::new(DashMap::new());
connections.insert(iid, ConnectionHandle { tx: tx.clone() });
let conns = Arc::clone(&connections);
let cancel = CancellationToken::new();
// Spawn the writer task
let writer = tokio::spawn(connection_writer_task(addr, iid, rx, conns, cancel));
// Accept the connection, then immediately drop it + the listener
let (stream, _) = listener.accept().await.unwrap();
drop(stream);
drop(listener);
// Send a message — the writer should hit a broken-pipe error
tx.send(SendTask {
msg_type: MessageType::Message,
header: Bytes::from_static(b"hdr"),
payload: Bytes::from_static(b"pay"),
on_error: Arc::new(NullErrorHandler),
})
.unwrap();
// Wait for writer task to finish
let _ = writer.await;
// The writer should have removed the stale entry from the map
assert!(
!connections.contains_key(&iid),
"writer task should clean up its DashMap entry on write error"
);
}
#[tokio::test]
async fn test_send_message_does_not_fail_on_stale_handle() {
let (transport, _our_addr) = make_transport();
// Start a listener that accepts connections (simulates a healthy peer)
let peer_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let peer_addr = peer_listener.local_addr().unwrap();
let peer = make_tcp_peer(peer_addr);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert a stale handle
insert_stale_handle(&transport, iid);
// send_message should detect the stale handle and create a new one,
// NOT immediately call on_error
let error_handler = Arc::new(TrackingErrorHandler::new());
transport.send_message(
iid,
b"test-header".to_vec(),
b"test-payload".to_vec(),
MessageType::Message,
error_handler.clone(),
);
// Accept the connection that the new writer task will establish
let (mut stream, _) = peer_listener.accept().await.unwrap();
// Read the framed message from the stream to confirm delivery
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 256];
// Give the async writer a moment to flush the frame
let n = tokio::time::timeout(Duration::from_secs(2), stream.read(&mut buf))
.await
.expect("timed out waiting for data")
.expect("read error");
assert!(n > 0, "expected data from the writer task");
// No errors should have been reported
assert_eq!(
error_handler.error_count(),
0,
"send_message should retry on stale handle, not fail"
);
// The connections map should now contain a live handle
let entry = transport.connections.get(&iid).unwrap();
assert!(
!entry.tx.is_disconnected(),
"stale handle should have been replaced with a live one"
);
}
#[tokio::test]
async fn test_writer_task_drains_on_connect_failure() {
// Use an address where nothing is listening so connect will fail.
// Binding then immediately dropping gives us a port that is guaranteed closed.
let tmp = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = tmp.local_addr().unwrap();
drop(tmp);
let iid = crate::InstanceId::new_v4();
let (tx, rx) = flume::bounded::<SendTask>(8);
let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
Arc::new(DashMap::new());
connections.insert(iid, ConnectionHandle { tx: tx.clone() });
// Queue a message *before* the writer task even starts — this simulates
// the race between create_connection returning and connect completing.
let error_handler = Arc::new(TrackingErrorHandler::new());
tx.send(SendTask {
msg_type: MessageType::Message,
header: Bytes::from_static(b"hdr"),
payload: Bytes::from_static(b"pay"),
on_error: error_handler.clone(),
})
.unwrap();
let conns = Arc::clone(&connections);
let cancel = CancellationToken::new();
let writer = tokio::spawn(connection_writer_task(addr, iid, rx, conns, cancel));
let _ = writer.await;
assert_eq!(
error_handler.error_count(),
1,
"queued message should have its on_error called when connect fails"
);
assert!(
!connections.contains_key(&iid),
"writer task should clean up its DashMap entry on connect failure"
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use bytes::Bytes;
use futures::future::BoxFuture;
use crate::{InstanceId, PeerInfo, TransportKey, WorkerAddress};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
/// Errors returned by individual [`Transport`] implementations.
#[derive(thiserror::Error, Debug)]
pub enum TransportError {
/// The peer's [`WorkerAddress`] does not contain an entry for this transport.
#[error("No endpoint found for transport")]
NoEndpoint,
/// The endpoint string could not be parsed (malformed URL, invalid address).
#[error("Invalid endpoint format")]
InvalidEndpoint,
/// The target peer was never registered with this transport.
#[error("Peer not registered: {0}")]
PeerNotRegistered(InstanceId),
/// The transport has not been started yet (no runtime handle).
#[error("Transport not started")]
NotStarted,
/// No responders available for the peer (e.g. NATS request with no subscriber).
#[error("No responders for peer")]
NoResponders,
}
/// Error type specific to health check operations
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckError {
/// The peer was never registered with this transport.
#[error("Peer not registered with transport")]
PeerNotRegistered,
/// The transport has not been started yet.
#[error("Transport not started")]
TransportNotStarted,
/// The peer is registered but no connection has ever been established.
#[error("Connection never established to peer")]
NeverConnected,
/// An existing connection is unhealthy or the peer is unreachable.
#[error("Connection failed or peer unreachable")]
ConnectionFailed,
/// The health check exceeded the specified timeout.
#[error("Health check timed out")]
Timeout,
}
/// Shared shutdown coordinator for graceful multi-phase shutdown.
///
/// **Phases**:
/// 1. **Gate** — `begin_drain()` flips the draining flag; transports reject new inbound requests.
/// 2. **Drain** — `wait_for_drain()` blocks until all in-flight guards are dropped.
/// 3. **Teardown** — `teardown_token().cancel()` kills listeners and writer tasks.
///
/// Hot-path cost: a single `AtomicBool::load(Relaxed)` per frame to check `is_draining()`.
#[derive(Clone)]
pub struct ShutdownState {
inner: Arc<ShutdownStateInner>,
}
struct ShutdownStateInner {
draining: AtomicBool,
in_flight: AtomicUsize,
drain_complete: Notify,
teardown_token: CancellationToken,
}
impl ShutdownState {
/// Create a new shutdown state. Not draining, zero in-flight.
pub fn new() -> Self {
Self {
inner: Arc::new(ShutdownStateInner {
draining: AtomicBool::new(false),
in_flight: AtomicUsize::new(0),
drain_complete: Notify::new(),
teardown_token: CancellationToken::new(),
}),
}
}
/// Returns `true` if drain has been initiated (Phase 1).
///
/// Uses `Relaxed` ordering — safe for the hot-path gate check because
/// the flag is monotonic (false → true, never reset).
#[inline]
pub fn is_draining(&self) -> bool {
self.inner.draining.load(Ordering::Relaxed)
}
/// Begin Phase 1: flip the draining flag. Idempotent.
pub fn begin_drain(&self) {
self.inner.draining.store(true, Ordering::Release);
}
/// Acquire an in-flight guard. The guard increments the counter on creation
/// and decrements it on drop. Use this to track requests that are being processed.
///
/// Guards are still acquirable after `begin_drain()` — this is intentional
/// so that already-accepted work can be tracked.
pub fn acquire(&self) -> InFlightGuard {
self.inner.in_flight.fetch_add(1, Ordering::AcqRel);
InFlightGuard {
inner: self.inner.clone(),
}
}
/// Current number of in-flight requests. Primarily for testing/debugging.
pub fn in_flight_count(&self) -> usize {
self.inner.in_flight.load(Ordering::Acquire)
}
/// Wait until in-flight count reaches zero. Returns immediately if already zero.
pub async fn wait_for_drain(&self) {
loop {
if self.inner.in_flight.load(Ordering::Acquire) == 0 {
return;
}
self.inner.drain_complete.notified().await;
}
}
/// Get the Phase 3 teardown token. Cancel this to kill listeners/writers.
pub fn teardown_token(&self) -> &CancellationToken {
&self.inner.teardown_token
}
}
impl Default for ShutdownState {
fn default() -> Self {
Self::new()
}
}
/// RAII guard that decrements the in-flight counter on drop.
pub struct InFlightGuard {
inner: Arc<ShutdownStateInner>,
}
impl InFlightGuard {
/// Explicitly complete this guard (equivalent to dropping it).
pub fn complete(self) {
// Drop impl handles the decrement
}
}
impl Drop for InFlightGuard {
fn drop(&mut self) {
let prev = self.inner.in_flight.fetch_sub(1, Ordering::AcqRel);
// If we just decremented to 0, notify waiters
if prev == 1 {
self.inner.drain_complete.notify_waiters();
}
}
}
/// Policy for how long to wait during the drain phase.
#[derive(Debug, Clone)]
pub enum ShutdownPolicy {
/// Wait indefinitely for all in-flight requests to complete.
WaitForever,
/// Wait up to the given duration, then force teardown.
Timeout(Duration),
}
/// Abstraction over a single message transport (TCP, HTTP, NATS, gRPC, UCX).
///
/// Implementations handle peer registration, message sending, listener lifecycle,
/// health checking, and graceful shutdown. The trait is object-safe so transports
/// can be stored as `Arc<dyn Transport>`.
pub trait Transport: Send + Sync {
/// Unique key identifying this transport (e.g. `"tcp"`, `"grpc"`).
fn key(&self) -> TransportKey;
/// The [`WorkerAddress`] fragment advertised by this transport.
fn address(&self) -> WorkerAddress;
/// Register a remote peer, extracting its endpoint from [`PeerInfo`].
fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError>;
/// Sends an active message to the remote instance
fn send_message(
&self,
instance_id: InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
);
/// Start the transport (bind listener, spawn tasks) for the given instance.
fn start(
&self,
instance_id: InstanceId,
channels: TransportAdapter,
rt: tokio::runtime::Handle,
) -> BoxFuture<'_, anyhow::Result<()>>;
/// Tear down the transport, cancelling all tasks and closing connections.
fn shutdown(&self);
/// Begin draining: reject new inbound requests while allowing responses.
///
/// Default implementation is a no-op. Transports that need per-frame
/// gating (e.g., unsubscribing from NATS subjects) should override this.
fn begin_drain(&self) {}
/// Check if a registered peer is reachable and healthy
///
/// Returns Ok(()) if peer responds to health check within timeout.
/// Different transports implement this differently:
/// - NATS: request/reply to health subject
/// - TCP: check existing connection or attempt new connection
/// - HTTP: HEAD request to health endpoint
/// - UCX: endpoint status check
///
/// # Errors
/// - `PeerNotRegistered`: Peer was never registered with this transport
/// - `TransportNotStarted`: Transport hasn't been started yet
/// - `NeverConnected`: Peer is registered but no connection has been established
/// - `ConnectionFailed`: Connection exists/existed but is currently unhealthy or unreachable
/// - `Timeout`: Health check took longer than the specified timeout
fn check_health(
&self,
instance_id: InstanceId,
timeout: Duration,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
>;
}
/// Callback trait invoked when a transport fails to deliver a message.
///
/// The original `header` and `payload` are returned so higher layers can
/// retry or log the failure.
pub trait TransportErrorHandler: Send + Sync {
/// Called when message delivery fails. Receives the original data and error description.
fn on_error(&self, header: Bytes, payload: Bytes, error: String);
}
/// Message type discriminator for routing frames to appropriate streams
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
#[allow(missing_docs)]
Message = 0,
#[allow(missing_docs)]
Response = 1,
#[allow(missing_docs)]
Ack = 2,
#[allow(missing_docs)]
Event = 3,
/// Sent back to a peer when we are draining and cannot accept new messages.
/// The original request header is echoed back for correlation.
ShuttingDown = 4,
}
impl MessageType {
/// Try to convert a u8 to a MessageType
pub fn from_u8(value: u8) -> Option<Self> {
match value {
0 => Some(MessageType::Message),
1 => Some(MessageType::Response),
2 => Some(MessageType::Ack),
3 => Some(MessageType::Event),
4 => Some(MessageType::ShuttingDown),
_ => None,
}
}
/// Convert MessageType to u8
pub fn as_u8(self) -> u8 {
self as u8
}
}
/// Sender-side handle given to transports for routing inbound frames.
///
/// Each transport receives a clone of this adapter during [`Transport::start`]
/// and uses it to forward decoded `(header, payload)` pairs to the appropriate
/// stream based on [`MessageType`].
#[derive(Clone)]
pub struct TransportAdapter {
/// Channel for inbound [`MessageType::Message`] frames.
pub message_stream: flume::Sender<(Bytes, Bytes)>,
/// Channel for inbound [`MessageType::Response`] and [`MessageType::ShuttingDown`] frames.
pub response_stream: flume::Sender<(Bytes, Bytes)>,
/// Channel for inbound [`MessageType::Ack`] and [`MessageType::Event`] frames.
pub event_stream: flume::Sender<(Bytes, Bytes)>,
/// Shared shutdown coordinator for drain-aware routing.
pub shutdown_state: ShutdownState,
}
/// Receiver-side handle for consuming inbound frames from all transports.
///
/// Returned by [`make_channels`] alongside the corresponding [`TransportAdapter`].
/// Higher layers pull `(header, payload)` pairs from these channels.
pub struct DataStreams {
/// Receiver for inbound message frames.
pub message_stream: flume::Receiver<(Bytes, Bytes)>,
/// Receiver for inbound response and shutting-down frames.
pub response_stream: flume::Receiver<(Bytes, Bytes)>,
/// Receiver for inbound ack and event frames.
pub event_stream: flume::Receiver<(Bytes, Bytes)>,
/// Shared shutdown coordinator.
pub shutdown_state: ShutdownState,
}
type DataStreamTuple = (
flume::Receiver<(Bytes, Bytes)>,
flume::Receiver<(Bytes, Bytes)>,
flume::Receiver<(Bytes, Bytes)>,
);
impl DataStreams {
/// Destructure into the three raw receivers `(message, response, event)`.
pub fn into_parts(self) -> DataStreamTuple {
(self.message_stream, self.response_stream, self.event_stream)
}
/// Receive a message with an in-flight guard for drain tracking.
///
/// Returns `(header, payload, guard)`. The guard keeps the in-flight counter
/// incremented until it is dropped or `complete()` is called.
pub async fn recv_message_tracked(
&self,
) -> Result<(Bytes, Bytes, InFlightGuard), flume::RecvError> {
let (header, payload) = self.message_stream.recv_async().await?;
let guard = self.shutdown_state.acquire();
Ok((header, payload, guard))
}
}
/// Create a matched pair of [`TransportAdapter`] (sender) and [`DataStreams`] (receiver).
///
/// Both sides share the same [`ShutdownState`] so drain coordination is automatic.
pub fn make_channels() -> (TransportAdapter, DataStreams) {
let shutdown_state = ShutdownState::new();
let (message_tx, message_rx) = flume::unbounded();
let (response_tx, response_rx) = flume::unbounded();
let (event_tx, event_rx) = flume::unbounded();
(
TransportAdapter {
message_stream: message_tx,
response_stream: response_tx,
event_stream: event_tx,
shutdown_state: shutdown_state.clone(),
},
DataStreams {
message_stream: message_rx,
response_stream: response_rx,
event_stream: event_rx,
shutdown_state,
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{sleep, timeout};
#[test]
fn test_shutdown_state_initial() {
let state = ShutdownState::new();
assert!(!state.is_draining());
assert_eq!(state.in_flight_count(), 0);
}
#[test]
fn test_begin_drain_flips_flag() {
let state = ShutdownState::new();
state.begin_drain();
assert!(state.is_draining());
}
#[test]
fn test_begin_drain_idempotent() {
let state = ShutdownState::new();
state.begin_drain();
state.begin_drain();
assert!(state.is_draining());
}
#[test]
fn test_acquire_increments_inflight() {
let state = ShutdownState::new();
let _g1 = state.acquire();
assert_eq!(state.in_flight_count(), 1);
let _g2 = state.acquire();
assert_eq!(state.in_flight_count(), 2);
}
#[test]
fn test_guard_drop_decrements_inflight() {
let state = ShutdownState::new();
let g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
drop(g);
assert_eq!(state.in_flight_count(), 0);
}
#[test]
fn test_guard_complete_decrements() {
let state = ShutdownState::new();
let g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
g.complete();
assert_eq!(state.in_flight_count(), 0);
}
#[tokio::test]
async fn test_wait_for_drain_immediate() {
let state = ShutdownState::new();
// Should complete immediately since in_flight is 0
timeout(Duration::from_millis(100), state.wait_for_drain())
.await
.expect("wait_for_drain should complete immediately when in_flight is 0");
}
#[tokio::test]
async fn test_wait_for_drain_blocks_then_completes() {
let state = ShutdownState::new();
let guard = state.acquire();
let state_clone = state.clone();
let handle = tokio::spawn(async move {
state_clone.wait_for_drain().await;
});
// Give the waiter time to park
sleep(Duration::from_millis(50)).await;
assert!(!handle.is_finished());
// Drop guard → should unblock
drop(guard);
timeout(Duration::from_millis(100), handle)
.await
.expect("should complete after guard drop")
.unwrap();
}
#[tokio::test]
async fn test_multiple_guards_concurrent() {
let state = ShutdownState::new();
let guards: Vec<_> = (0..10).map(|_| state.acquire()).collect();
assert_eq!(state.in_flight_count(), 10);
let state_clone = state.clone();
let handle = tokio::spawn(async move {
state_clone.wait_for_drain().await;
});
// Drop all guards
drop(guards);
timeout(Duration::from_millis(100), handle)
.await
.expect("should complete after all guards drop")
.unwrap();
assert_eq!(state.in_flight_count(), 0);
}
#[tokio::test]
async fn test_drain_with_zero_inflight() {
let state = ShutdownState::new();
state.begin_drain();
// Should complete immediately
timeout(Duration::from_millis(100), state.wait_for_drain())
.await
.expect("should complete immediately with zero in-flight");
}
#[test]
fn test_acquire_works_after_drain() {
let state = ShutdownState::new();
state.begin_drain();
let _g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
}
#[test]
fn test_guard_drop_during_panic() {
let state = ShutdownState::new();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
panic!("intentional panic");
}));
assert!(result.is_err());
// Guard's Drop should have fired even during unwind
assert_eq!(state.in_flight_count(), 0);
}
#[test]
fn test_shutting_down_from_u8() {
assert_eq!(MessageType::from_u8(4), Some(MessageType::ShuttingDown));
}
#[test]
fn test_shutting_down_as_u8() {
assert_eq!(MessageType::ShuttingDown.as_u8(), 4);
}
#[test]
fn test_unknown_message_type_still_none() {
assert_eq!(MessageType::from_u8(5), None);
assert_eq!(MessageType::from_u8(255), None);
}
#[test]
fn test_make_channels_includes_shutdown_state() {
let (adapter, streams) = make_channels();
// Both sides should share the same ShutdownState (via Arc)
assert!(!adapter.shutdown_state.is_draining());
assert!(!streams.shutdown_state.is_draining());
// Mutating one should be visible through the other
adapter.shutdown_state.begin_drain();
assert!(streams.shutdown_state.is_draining());
}
#[tokio::test]
async fn test_recv_message_tracked_returns_guard() {
let (adapter, streams) = make_channels();
// Send a message through the adapter
adapter
.message_stream
.send_async((
bytes::Bytes::from_static(b"hdr"),
bytes::Bytes::from_static(b"pay"),
))
.await
.unwrap();
// Receive with tracking
let (header, payload, guard) = streams.recv_message_tracked().await.unwrap();
assert_eq!(&header[..], b"hdr");
assert_eq!(&payload[..], b"pay");
assert_eq!(streams.shutdown_state.in_flight_count(), 1);
// Drop guard
drop(guard);
assert_eq!(streams.shutdown_state.in_flight_count(), 0);
}
#[test]
fn test_shutdown_state_clone_shares_inner() {
let s1 = ShutdownState::new();
let s2 = s1.clone();
s1.begin_drain();
assert!(s2.is_draining());
let _g = s1.acquire();
assert_eq!(s2.in_flight_count(), 1);
}
#[test]
fn test_teardown_token() {
let state = ShutdownState::new();
assert!(!state.teardown_token().is_cancelled());
state.teardown_token().cancel();
assert!(state.teardown_token().is_cancelled());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! UDS listener for ActiveMessage transport
//!
//! Mirrors `tcp/listener.rs` but uses `UnixListener`/`UnixStream`.
//! Reuses `TcpFrameCodec` for framing. Supports drain-aware frame handling
//! via `ShutdownState`.
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::StreamExt;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::UnixListener as TokioUnixListener;
use tokio::net::UnixStream;
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
use crate::tcp::TcpFrameCodec;
/// UDS listener for ActiveMessage transport
///
/// Accepts incoming Unix domain socket connections and routes decoded frames
/// to the appropriate transport streams. Supports graceful drain: during drain,
/// new `Message` frames are rejected with a `ShuttingDown` response while
/// `Response`/`Event`/`Ack` frames continue to flow.
pub struct UdsListener {
socket_path: PathBuf,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
}
/// UDS listener that has been bound to a socket path, ready to accept connections.
///
/// Created by [`UdsListener::bind`]. Holding this value proves the OS-level bind
/// succeeded, so callers can detect failures before spawning a task.
pub struct BoundUdsListener {
socket_path: PathBuf,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
listener: TokioUnixListener,
}
impl UdsListener {
/// Create a new builder for UdsListener
pub fn builder() -> UdsListenerBuilder {
UdsListenerBuilder::new()
}
/// Bind to the socket path and return a [`BoundUdsListener`] ready to serve.
///
/// `TokioUnixListener::bind` is synchronous, so this method is also
/// synchronous. Callers that need to propagate bind failures before spawning
/// a task should call `bind()` first, then spawn `bound.serve()`.
pub fn bind(self) -> Result<BoundUdsListener> {
let listener = TokioUnixListener::bind(&self.socket_path)
.with_context(|| format!("Failed to bind UDS listener to {:?}", self.socket_path))?;
info!("UDS listener bound to {:?}", self.socket_path);
Ok(BoundUdsListener {
socket_path: self.socket_path,
adapter: self.adapter,
error_handler: self.error_handler,
shutdown_state: self.shutdown_state,
listener,
})
}
/// Convenience shim: bind and serve in one call.
pub async fn serve(self) -> Result<()> {
self.bind()?.serve().await
}
/// Handle a single UDS connection
async fn handle_connection(
stream: UnixStream,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
) -> Result<()> {
debug!("Configuring UDS connection");
// Create framed stream with zero-copy codec (same as TCP)
let mut framed = Framed::new(stream, TcpFrameCodec::new());
let teardown_token = shutdown_state.teardown_token().clone();
debug!("UDS connection ready for frames");
loop {
tokio::select! {
frame_result = framed.next() => {
match frame_result {
Some(Ok((msg_type, header, payload))) => {
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if shutdown_state.is_draining() && msg_type == MessageType::Message {
debug!(
"Rejecting Message frame during drain (sending ShuttingDown)"
);
// Echo original header back for correlation, empty payload
if let Err(e) = TcpFrameCodec::encode_frame(
framed.get_mut(),
MessageType::ShuttingDown,
&header,
&[],
)
.await
{
warn!(
"Failed to send ShuttingDown frame: {}",
e
);
}
continue;
}
if let Err(e) = Self::route_frame(
msg_type,
header,
payload,
&adapter,
&error_handler,
)
.await
{
warn!(
"Failed to route {:?} frame from UDS: {}",
msg_type, e
);
}
}
Some(Err(e)) => {
error!("Frame decode error from UDS: {}", e);
break;
}
None => {
debug!("UDS connection closed gracefully");
break;
}
}
}
_ = teardown_token.cancelled() => {
debug!("UDS connection handler torn down");
break;
}
}
}
Ok(())
}
/// Route a decoded frame to the appropriate stream
async fn route_frame(
msg_type: MessageType,
header: Bytes,
payload: Bytes,
adapter: &TransportAdapter,
error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
let sender = match msg_type {
MessageType::Message => &adapter.message_stream,
MessageType::Response => &adapter.response_stream,
MessageType::Ack | MessageType::Event => &adapter.event_stream,
MessageType::ShuttingDown => {
// ShuttingDown is an outbound-only frame type; receiving it here
// means a remote peer rejected our request. Route to the response
// stream so higher layers can handle the rejection via correlation.
&adapter.response_stream
}
};
match sender.send_async((header, payload)).await {
Ok(_) => Ok(()),
Err(e) => {
error_handler.on_error(e.0.0, e.0.1, format!("Failed to route {:?}", msg_type));
Err(anyhow::anyhow!("Failed to send to stream"))
}
}
}
}
impl BoundUdsListener {
/// Accept connections until the teardown token is cancelled.
pub async fn serve(self) -> Result<()> {
let teardown_token = self.shutdown_state.teardown_token().clone();
loop {
tokio::select! {
accept_result = self.listener.accept() => {
match accept_result {
Ok((stream, _addr)) => {
debug!("Accepted UDS connection");
let adapter = self.adapter.clone();
let error_handler = self.error_handler.clone();
let shutdown_state = self.shutdown_state.clone();
tokio::spawn(async move {
if let Err(e) = UdsListener::handle_connection(
stream,
adapter,
error_handler,
shutdown_state,
)
.await
{
warn!("Error handling UDS connection: {}", e);
}
});
}
Err(e) => {
error!("Failed to accept UDS connection: {}", e);
}
}
}
_ = teardown_token.cancelled() => {
info!("UDS listener shutting down (teardown)");
break;
}
}
}
// Clean up socket file
std::fs::remove_file(&self.socket_path).ok();
Ok(())
}
}
/// Builder for UdsListener
pub struct UdsListenerBuilder {
socket_path: Option<PathBuf>,
adapter: Option<TransportAdapter>,
error_handler: Option<Arc<dyn TransportErrorHandler>>,
shutdown_state: Option<ShutdownState>,
}
impl UdsListenerBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
socket_path: None,
adapter: None,
error_handler: None,
shutdown_state: None,
}
}
/// Set the socket path
pub fn socket_path(mut self, path: PathBuf) -> Self {
self.socket_path = Some(path);
self
}
/// Set the transport adapter
pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
self.adapter = Some(adapter);
self
}
/// Set the error handler
pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
self.error_handler = Some(handler);
self
}
/// Set the shutdown state for graceful drain coordination
pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
self.shutdown_state = Some(state);
self
}
/// Build the UdsListener
pub fn build(self) -> Result<UdsListener> {
let socket_path = self
.socket_path
.ok_or_else(|| anyhow::anyhow!("socket_path is required"))?;
let adapter = self
.adapter
.ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
let error_handler = self
.error_handler
.ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
let shutdown_state = self.shutdown_state.unwrap_or_default();
Ok(UdsListener {
socket_path,
adapter,
error_handler,
shutdown_state,
})
}
}
impl Default for UdsListenerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::make_channels;
struct TestErrorHandler;
impl TransportErrorHandler for TestErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
eprintln!("Test error handler: {}", error);
}
}
#[test]
fn test_builder_requires_fields() {
let result = UdsListener::builder().build();
assert!(result.is_err());
}
#[tokio::test]
async fn test_builder_with_all_fields() {
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = UdsListener::builder()
.socket_path(PathBuf::from("/tmp/test-uds-listener.sock"))
.adapter(adapter)
.error_handler(error_handler)
.build();
assert!(result.is_ok());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Unix Domain Socket (UDS) Transport Module
//!
//! This module provides a UDS transport implementation that mirrors the TCP transport
//! but uses Unix domain sockets instead of TCP connections. It reuses the same
//! zero-copy frame codec (`TcpFrameCodec`) since the framing protocol is transport-agnostic.
//!
//! Key differences from TCP:
//! - Uses `PathBuf` instead of `SocketAddr`
//! - Uses `UnixStream`/`UnixListener` instead of `TcpStream`/`TcpListener`
//! - No TCP-specific options (nodelay, keepalive, CPU pinning)
//! - Endpoint format: `uds:///path/to/socket`
//!
//! This transport is ideal for same-host communication (e.g., daemon-to-container via
//! bind-mounted sockets), avoiding the overhead of the TCP/IP stack entirely.
mod listener;
mod transport;
pub use listener::{UdsListener, UdsListenerBuilder};
pub use transport::{UdsTransport, UdsTransportBuilder};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! UDS transport implementation
//!
//! Structural mirror of the TCP transport (`tcp/transport.rs`), replacing
//! `TcpStream`/`TcpListener` with `UnixStream`/`UnixListener`.
//! Reuses `TcpFrameCodec` for framing since it operates on any `AsyncRead + AsyncWrite`.
use anyhow::{Context, Result};
use bytes::Bytes;
use dashmap::DashMap;
use std::os::unix::fs::FileTypeExt;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tokio::net::UnixStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use crate::transport::{HealthCheckError, ShutdownState, TransportError, TransportErrorHandler};
use crate::{MessageType, PeerInfo, Transport, TransportAdapter, TransportKey, WorkerAddress};
use super::listener::UdsListener;
use crate::tcp::TcpFrameCodec;
/// UDS transport with lock-free concurrent access
///
/// Mirrors `TcpTransport` but uses Unix domain sockets.
pub struct UdsTransport {
key: TransportKey,
socket_path: PathBuf,
local_address: WorkerAddress,
// Shared mutable state with DashMap (lock-free)
peers: Arc<DashMap<crate::InstanceId, PathBuf>>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
// Runtime handle for spawning tasks
runtime: OnceLock<tokio::runtime::Handle>,
// Shutdown coordination
cancel_token: CancellationToken,
shutdown_state: OnceLock<ShutdownState>,
// Send channel capacity for backpressure
channel_capacity: usize,
}
/// Handle to a connection's writer task
#[derive(Clone)]
struct ConnectionHandle {
tx: flume::Sender<SendTask>,
}
/// Task sent to writer task containing pre-encoded frame
struct SendTask {
msg_type: MessageType,
header: Bytes,
payload: Bytes,
on_error: Arc<dyn TransportErrorHandler>,
}
impl SendTask {
fn on_error(self, error: impl Into<String>) {
self.on_error
.on_error(self.header, self.payload, error.into());
}
}
impl UdsTransport {
/// Create a new UDS transport
pub fn new(
socket_path: PathBuf,
key: TransportKey,
local_address: WorkerAddress,
channel_capacity: usize,
) -> Self {
Self {
key,
socket_path,
local_address,
peers: Arc::new(DashMap::new()),
connections: Arc::new(DashMap::new()),
runtime: OnceLock::new(),
cancel_token: CancellationToken::new(),
shutdown_state: OnceLock::new(),
channel_capacity,
}
}
/// Get the socket path this transport is bound to
pub fn socket_path(&self) -> &Path {
&self.socket_path
}
/// Optional: Pre-establish connection after registration
pub fn ensure_connected(&self, instance_id: crate::InstanceId) -> Result<()> {
self.get_or_create_connection(instance_id)?;
Ok(())
}
/// Get or create a connection to a peer (lazy initialization)
fn get_or_create_connection(&self, instance_id: crate::InstanceId) -> Result<ConnectionHandle> {
// Fast path: connection already exists and is alive
if let Some(handle) = self.connections.get(&instance_id) {
if !handle.tx.is_disconnected() {
return Ok(handle.clone());
}
// Stale — drop guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
}
let rt = self.runtime.get().ok_or(TransportError::NotStarted)?;
// Atomic check-and-insert via entry API
let handle = match self.connections.entry(instance_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
if !entry.get().tx.is_disconnected() {
entry.get().clone()
} else {
// Stale entry — replace in-place with a fresh connection
let handle = self.create_connection(instance_id, rt)?;
entry.insert(handle.clone());
handle
}
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let handle = self.create_connection(instance_id, rt)?;
entry.insert(handle.clone());
handle
}
};
Ok(handle)
}
/// Create a new connection handle and spawn the writer task.
fn create_connection(
&self,
instance_id: crate::InstanceId,
rt: &tokio::runtime::Handle,
) -> Result<ConnectionHandle> {
let path = self
.peers
.get(&instance_id)
.ok_or(TransportError::PeerNotRegistered(instance_id))?
.value()
.clone();
let (tx, rx) = flume::bounded(self.channel_capacity);
let handle = ConnectionHandle { tx };
let cancel = self.cancel_token.clone();
let conns = Arc::clone(&self.connections);
debug!("Created new UDS connection to {} ({:?})", instance_id, path);
rt.spawn(connection_writer_task(path, instance_id, rx, conns, cancel));
Ok(handle)
}
}
impl Transport for UdsTransport {
fn key(&self) -> TransportKey {
self.key.clone()
}
fn address(&self) -> WorkerAddress {
self.local_address.clone()
}
fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError> {
// Get endpoint from peer's address
let endpoint = peer_info
.worker_address()
.get_entry(&self.key)
.map_err(|_| TransportError::NoEndpoint)?
.ok_or(TransportError::NoEndpoint)?;
// Parse UDS endpoint (expected format: "uds:///path/to/socket" or "/path/to/socket")
let path = parse_uds_endpoint(&endpoint).map_err(|e| {
error!("Failed to parse UDS endpoint: {}", e);
TransportError::InvalidEndpoint
})?;
// Store peer path
self.peers.insert(peer_info.instance_id(), path.clone());
debug!("Registered peer {} at {:?}", peer_info.instance_id(), path);
Ok(())
}
#[inline]
fn send_message(
&self,
instance_id: crate::InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
) {
let header = Bytes::from(header);
let payload = Bytes::from(payload);
let send_msg = SendTask {
msg_type: message_type,
header,
payload,
on_error,
};
// Fast path: try to send on existing connection
let send_msg = match self.connections.get(&instance_id) {
Some(handle) => match handle.tx.try_send(send_msg) {
Ok(()) => return,
Err(flume::TrySendError::Full(send_msg)) => send_msg,
Err(flume::TrySendError::Disconnected(send_msg)) => {
// Drop the guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
// Fall through to slow path to create a fresh connection
send_msg
}
},
None => send_msg,
};
// Slow path: create new connection
let rt = match self.runtime.get() {
Some(rt) => rt,
None => {
send_msg.on_error("Transport not started");
return;
}
};
let handle = match self.get_or_create_connection(instance_id) {
Ok(h) => h,
Err(e) => {
send_msg.on_error(format!("Failed to create connection: {}", e));
return;
}
};
rt.spawn(async move {
if let Err(flume::SendError(send_msg)) = handle.tx.send_async(send_msg).await {
send_msg.on_error("Connection closed");
}
});
}
fn start(
&self,
_instance_id: crate::InstanceId,
channels: TransportAdapter,
rt: tokio::runtime::Handle,
) -> futures::future::BoxFuture<'_, anyhow::Result<()>> {
// Store runtime handle for use in send_message
self.runtime.set(rt.clone()).ok();
// Capture shutdown state from the adapter
self.shutdown_state
.set(channels.shutdown_state.clone())
.ok();
let socket_path = self.socket_path.clone();
let shutdown_state = channels.shutdown_state.clone();
Box::pin(async move {
struct DefaultErrorHandler;
impl TransportErrorHandler for DefaultErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
warn!("UDS transport error: {}", error);
}
}
// Remove a stale socket file only when it is safe to do so.
if socket_path.exists() {
let is_socket = std::fs::metadata(&socket_path)
.map(|m| m.file_type().is_socket())
.unwrap_or(false);
if !is_socket {
anyhow::bail!(
"path {:?} exists and is not a Unix domain socket",
socket_path
);
}
// Probe liveness: a successful connect means a live listener owns it.
match tokio::time::timeout(
Duration::from_millis(100),
UnixStream::connect(&socket_path),
)
.await
{
Ok(Ok(_)) => {
anyhow::bail!(
"a live UDS listener is already running at {:?}",
socket_path
);
}
_ => {
// Stale (connection refused / timeout) — safe to unlink.
std::fs::remove_file(&socket_path).ok();
}
}
}
// Build and bind before spawning so that start() only returns Ok
// after the OS-level bind succeeds.
let uds_listener = UdsListener::builder()
.socket_path(socket_path.clone())
.adapter(channels)
.error_handler(Arc::new(DefaultErrorHandler))
.shutdown_state(shutdown_state)
.build()?;
let bound_listener = uds_listener.bind()?;
rt.spawn(async move {
if let Err(e) = bound_listener.serve().await {
error!("UDS listener error: {}", e);
}
});
info!("UDS transport started on {:?}", socket_path);
Ok(())
})
}
fn begin_drain(&self) {
if let Some(state) = self.shutdown_state.get() {
state.begin_drain();
}
}
fn shutdown(&self) {
info!("Shutting down UDS transport");
// Cancel the teardown token (Phase 3) to stop the listener and connection handlers
if let Some(state) = self.shutdown_state.get() {
state.teardown_token().cancel();
}
self.cancel_token.cancel();
// Clear connections
self.connections.clear();
}
fn check_health(
&self,
instance_id: crate::InstanceId,
timeout: Duration,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
> {
Box::pin(async move {
let connection_exists = self.connections.contains_key(&instance_id);
if let Some(handle) = self.connections.get(&instance_id) {
if !handle.tx.is_disconnected() {
return Ok(());
}
// Channel is disconnected — drop guard and remove stale entry
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
}
// No existing connection or connection is dead - verify peer is reachable
let path = self
.peers
.get(&instance_id)
.ok_or(HealthCheckError::PeerNotRegistered)?
.value()
.clone();
// Try to connect (and immediately drop) to verify peer is reachable
match tokio::time::timeout(timeout, UnixStream::connect(&path)).await {
Ok(Ok(_stream)) => {
if connection_exists {
Ok(())
} else {
Err(HealthCheckError::NeverConnected)
}
}
Ok(Err(_)) => Err(HealthCheckError::ConnectionFailed),
Err(_) => Err(HealthCheckError::Timeout),
}
})
}
}
/// Connection writer task for UDS
///
/// Mirrors the TCP connection_writer_task. Cleanup (draining queued messages
/// and removing the stale map entry) always runs, even if the initial connect fails.
async fn connection_writer_task(
path: PathBuf,
instance_id: crate::InstanceId,
rx: flume::Receiver<SendTask>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
cancel_token: CancellationToken,
) -> Result<()> {
let result = connection_writer_inner(&path, instance_id, &rx, &cancel_token).await;
// Always drain queued messages and notify their error handlers.
while let Ok(msg) = rx.try_recv() {
msg.on_error("Connection closed");
}
// Drop the receiver so our sender half becomes disconnected, then remove
// the stale entry. The predicate ensures we only remove our own entry —
// a replacement connection's tx will still be connected.
drop(rx);
connections.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
debug!("UDS connection to {} ({:?}) closed", instance_id, path);
result
}
/// Inner loop: connect and send frames until the channel closes or a write error occurs.
async fn connection_writer_inner(
path: &Path,
instance_id: crate::InstanceId,
rx: &flume::Receiver<SendTask>,
cancel_token: &CancellationToken,
) -> Result<()> {
debug!("Connecting to UDS {:?}", path);
let mut stream = tokio::select! {
_ = cancel_token.cancelled() => return Ok(()),
res = UnixStream::connect(path) => res.context("UDS connect failed")?,
};
debug!("Connected to UDS {:?}", path);
// Main send loop
loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => break,
res = rx.recv_async() => match res {
Ok(msg) => msg,
Err(_) => break,
},
};
if let Err(e) =
TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await
{
error!("Write error to {} ({:?}): {}", instance_id, path, e);
msg.on_error(format!("Failed to write to UDS stream: {}", e));
break;
}
}
Ok(())
}
/// Parse a UDS endpoint string into a PathBuf
///
/// Accepts formats:
/// - "uds:///path/to/socket"
/// - "/path/to/socket"
fn parse_uds_endpoint(endpoint: &[u8]) -> Result<PathBuf> {
let endpoint_str = std::str::from_utf8(endpoint).context("endpoint is not valid UTF-8")?;
// Strip "uds://" prefix if present
let path_str = endpoint_str.strip_prefix("uds://").unwrap_or(endpoint_str);
if path_str.is_empty() {
anyhow::bail!("empty UDS socket path");
}
Ok(PathBuf::from(path_str))
}
/// Builder for UdsTransport
pub struct UdsTransportBuilder {
socket_path: Option<PathBuf>,
key: Option<TransportKey>,
channel_capacity: usize,
}
impl UdsTransportBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
socket_path: None,
key: None,
channel_capacity: 256,
}
}
/// Set the socket path
pub fn socket_path(mut self, path: impl Into<PathBuf>) -> Self {
self.socket_path = Some(path.into());
self
}
/// Set the transport key
pub fn key(mut self, key: TransportKey) -> Self {
self.key = Some(key);
self
}
/// Set the channel capacity for backpressure (default: 256)
pub fn channel_capacity(mut self, capacity: usize) -> Self {
self.channel_capacity = capacity;
self
}
/// Build the UdsTransport
pub fn build(self) -> Result<UdsTransport> {
let socket_path = self
.socket_path
.ok_or_else(|| anyhow::anyhow!("socket_path is required"))?;
let key = self.key.unwrap_or_else(|| TransportKey::from("uds"));
let local_endpoint = format!("uds://{}", socket_path.display());
let mut addr_builder = crate::address::WorkerAddressBuilder::new();
addr_builder.add_entry(key.clone(), local_endpoint.as_bytes().to_vec())?;
let local_address = addr_builder.build()?;
Ok(UdsTransport::new(
socket_path,
key,
local_address,
self.channel_capacity,
))
}
}
impl Default for UdsTransportBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::address::WorkerAddressBuilder;
use std::sync::atomic::{AtomicUsize, Ordering};
use velo_common::PeerInfo;
/// Error handler that discards errors (for tests that don't need to track them).
struct NullErrorHandler;
impl TransportErrorHandler for NullErrorHandler {
fn on_error(&self, _: Bytes, _: Bytes, _: String) {}
}
/// Error handler that counts errors (for tests that verify error routing).
struct TrackingErrorHandler {
count: AtomicUsize,
}
impl TrackingErrorHandler {
fn new() -> Self {
Self {
count: AtomicUsize::new(0),
}
}
fn error_count(&self) -> usize {
self.count.load(Ordering::SeqCst)
}
}
impl TransportErrorHandler for TrackingErrorHandler {
fn on_error(&self, _: Bytes, _: Bytes, _: String) {
self.count.fetch_add(1, Ordering::SeqCst);
}
}
/// Build a `PeerInfo` whose UDS endpoint points at `path`.
fn make_uds_peer(path: &Path) -> PeerInfo {
let instance_id = crate::InstanceId::new_v4();
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("uds", format!("uds://{}", path.display()).into_bytes())
.unwrap();
PeerInfo::new(instance_id, builder.build().unwrap())
}
/// Build a `UdsTransport` with its runtime set, bound to a temp socket path.
/// Returns `(transport, socket_path)`.
fn make_transport() -> (UdsTransport, PathBuf) {
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("test.sock");
let transport = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
transport
.runtime
.set(tokio::runtime::Handle::current())
.ok();
(transport, socket_path)
}
/// Insert a stale `ConnectionHandle` into the transport's connections map.
fn insert_stale_handle(transport: &UdsTransport, instance_id: crate::InstanceId) {
let (tx, _rx) = flume::bounded::<SendTask>(1);
// Drop _rx immediately so tx.is_disconnected() == true
transport
.connections
.insert(instance_id, ConnectionHandle { tx });
}
#[test]
fn test_parse_uds_endpoint() {
// With uds:// prefix
let path = parse_uds_endpoint(b"uds:///tmp/test.sock").unwrap();
assert_eq!(path, PathBuf::from("/tmp/test.sock"));
// Without prefix
let path = parse_uds_endpoint(b"/var/run/anvil.sock").unwrap();
assert_eq!(path, PathBuf::from("/var/run/anvil.sock"));
// Empty path
assert!(parse_uds_endpoint(b"").is_err());
}
#[test]
fn test_builder_requires_socket_path() {
let result = UdsTransportBuilder::new().build();
assert!(result.is_err());
}
#[test]
fn test_builder_with_socket_path() {
let result = UdsTransportBuilder::new()
.socket_path("/tmp/test.sock")
.build();
assert!(result.is_ok());
}
#[test]
fn test_builder_custom_key() {
let transport = UdsTransportBuilder::new()
.socket_path("/tmp/test.sock")
.key(TransportKey::from("custom-uds"))
.build()
.unwrap();
assert_eq!(transport.key(), TransportKey::from("custom-uds"));
}
#[test]
fn test_transport_socket_path() {
let transport = UdsTransportBuilder::new()
.socket_path("/tmp/test.sock")
.build()
.unwrap();
assert_eq!(transport.socket_path(), Path::new("/tmp/test.sock"));
}
#[tokio::test]
async fn test_get_or_create_connection_replaces_stale_handle() {
let (transport, _socket_path) = make_transport();
// Start a UDS listener that the transport can connect to
let dir = std::env::temp_dir().join(format!("uds-peer-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let peer_socket = dir.join("peer.sock");
let peer_listener = tokio::net::UnixListener::bind(&peer_socket).unwrap();
let peer = make_uds_peer(&peer_socket);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert a stale handle
insert_stale_handle(&transport, iid);
assert!(
transport
.connections
.get(&iid)
.unwrap()
.tx
.is_disconnected()
);
// get_or_create_connection should replace the stale handle with a live one
let handle = transport.get_or_create_connection(iid).unwrap();
assert!(!handle.tx.is_disconnected());
// The map entry should also be live
let entry = transport.connections.get(&iid).unwrap();
assert!(!entry.tx.is_disconnected());
// Cleanup
drop(peer_listener);
std::fs::remove_file(&peer_socket).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_check_health_removes_stale_entry() {
let (transport, _socket_path) = make_transport();
// Start a UDS listener so the peer is "reachable"
let dir = std::env::temp_dir().join(format!("uds-peer-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let peer_socket = dir.join("peer.sock");
let _peer_listener = tokio::net::UnixListener::bind(&peer_socket).unwrap();
let peer = make_uds_peer(&peer_socket);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert stale handle — simulates a dead writer task
insert_stale_handle(&transport, iid);
assert!(transport.connections.contains_key(&iid));
// check_health should remove the stale entry and verify the peer is reachable
let result = transport.check_health(iid, Duration::from_secs(2)).await;
// Stale entry should be gone
assert!(!transport.connections.contains_key(&iid));
// Since there WAS a previous connection entry, check_health returns Ok
assert!(result.is_ok());
// Cleanup
std::fs::remove_file(&peer_socket).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_writer_task_cleans_up_on_write_error() {
// Bind a UDS listener, accept once, then drop everything to cause a write error
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("writer-test.sock");
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
let iid = crate::InstanceId::new_v4();
let (tx, rx) = flume::bounded::<SendTask>(8);
let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
Arc::new(DashMap::new());
connections.insert(iid, ConnectionHandle { tx: tx.clone() });
let conns = Arc::clone(&connections);
let cancel = CancellationToken::new();
// Spawn the writer task
let writer = tokio::spawn(connection_writer_task(
socket_path.clone(),
iid,
rx,
conns,
cancel,
));
// Accept the connection, then immediately drop it + the listener
let (stream, _) = listener.accept().await.unwrap();
drop(stream);
drop(listener);
// Send a message — the writer should hit a broken-pipe error
tx.send(SendTask {
msg_type: MessageType::Message,
header: Bytes::from_static(b"hdr"),
payload: Bytes::from_static(b"pay"),
on_error: Arc::new(NullErrorHandler),
})
.unwrap();
// Wait for writer task to finish
let _ = writer.await;
// The writer should have removed the stale entry from the map
assert!(
!connections.contains_key(&iid),
"writer task should clean up its DashMap entry on write error"
);
// Cleanup
std::fs::remove_file(&socket_path).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_send_message_does_not_fail_on_stale_handle() {
let (transport, _socket_path) = make_transport();
// Start a UDS listener that accepts connections (simulates a healthy peer)
let dir = std::env::temp_dir().join(format!("uds-peer-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let peer_socket = dir.join("peer.sock");
let peer_listener = tokio::net::UnixListener::bind(&peer_socket).unwrap();
let peer = make_uds_peer(&peer_socket);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert a stale handle
insert_stale_handle(&transport, iid);
// send_message should detect the stale handle and create a new one
let error_handler = Arc::new(TrackingErrorHandler::new());
transport.send_message(
iid,
b"test-header".to_vec(),
b"test-payload".to_vec(),
MessageType::Message,
error_handler.clone(),
);
// Accept the connection that the new writer task will establish
let (mut stream, _) = peer_listener.accept().await.unwrap();
// Read the framed message from the stream to confirm delivery
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 256];
let n = tokio::time::timeout(Duration::from_secs(2), stream.read(&mut buf))
.await
.expect("timed out waiting for data")
.expect("read error");
assert!(n > 0, "expected data from the writer task");
// No errors should have been reported
assert_eq!(
error_handler.error_count(),
0,
"send_message should retry on stale handle, not fail"
);
// The connections map should now contain a live handle
let entry = transport.connections.get(&iid).unwrap();
assert!(
!entry.tx.is_disconnected(),
"stale handle should have been replaced with a live one"
);
// Cleanup
std::fs::remove_file(&peer_socket).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_double_bind_returns_err() {
use crate::transport::make_channels;
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("double-bind.sock");
let transport1 = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
let instance_id = crate::InstanceId::new_v4();
let (adapter1, _streams1) = make_channels();
let rt = tokio::runtime::Handle::current();
// First bind must succeed.
transport1
.start(instance_id, adapter1, rt.clone())
.await
.unwrap();
// Second transport on the same path must fail.
let transport2 = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
let (adapter2, _streams2) = make_channels();
let result = transport2.start(instance_id, adapter2, rt).await;
assert!(
result.is_err(),
"start() should return Err when a live listener already owns the socket"
);
// Cleanup
transport1.shutdown();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_begin_drain_activates_draining_flag() {
use crate::transport::make_channels;
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("drain-test.sock");
let transport = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
let instance_id = crate::InstanceId::new_v4();
let (adapter, _streams) = make_channels();
let rt = tokio::runtime::Handle::current();
transport.start(instance_id, adapter, rt).await.unwrap();
assert!(
!transport.shutdown_state.get().unwrap().is_draining(),
"should not be draining before begin_drain()"
);
transport.begin_drain();
assert!(
transport.shutdown_state.get().unwrap().is_draining(),
"should be draining after begin_drain()"
);
// Cleanup
transport.shutdown();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_writer_task_drains_on_connect_failure() {
// Use a socket path where nothing is listening so connect will fail.
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let dead_socket = dir.join("dead.sock");
let iid = crate::InstanceId::new_v4();
let (tx, rx) = flume::bounded::<SendTask>(8);
let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
Arc::new(DashMap::new());
connections.insert(iid, ConnectionHandle { tx: tx.clone() });
// Queue a message before the writer task starts
let error_handler = Arc::new(TrackingErrorHandler::new());
tx.send(SendTask {
msg_type: MessageType::Message,
header: Bytes::from_static(b"hdr"),
payload: Bytes::from_static(b"pay"),
on_error: error_handler.clone(),
})
.unwrap();
let conns = Arc::clone(&connections);
let cancel = CancellationToken::new();
let writer = tokio::spawn(connection_writer_task(dead_socket, iid, rx, conns, cancel));
let _ = writer.await;
assert_eq!(
error_handler.error_count(),
1,
"queued message should have its on_error called when connect fails"
);
assert!(
!connections.contains_key(&iid),
"writer task should clean up its DashMap entry on connect failure"
);
// Cleanup
std::fs::remove_dir_all(&dir).ok();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Common test utilities for transport integration tests
//!
//! This module provides a transport-agnostic test infrastructure that can be reused
//! across different transport implementations (TCP, RDMA, UDP, UDS, etc.).
#![allow(dead_code)]
// #[cfg(feature = "grpc")]
// use velo_transports::grpc::{GrpcTransport, GrpcTransportBuilder};
// #[cfg(feature = "http")]
// use velo_transports::http::{HttpTransport, HttpTransportBuilder};
// #[cfg(feature = "nats")]
// use velo_transports::nats::{NatsTransport, NatsTransportBuilder};
use bytes::Bytes;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::timeout;
use velo_transports::{
DataStreams, InstanceId, MessageType, PeerInfo, Transport, TransportErrorHandler,
tcp::{TcpTransport, TcpTransportBuilder},
};
#[cfg(unix)]
use velo_transports::uds::{UdsTransport, UdsTransportBuilder};
use std::sync::Once;
use tracing_subscriber::FmtSubscriber;
#[allow(dead_code)]
static INIT: Once = Once::new();
#[allow(dead_code)]
pub fn init_tracing() {
INIT.call_once(|| {
let _ = FmtSubscriber::builder()
.with_env_filter("trace") // or "info"
.try_init();
});
}
pub mod scenarios;
/// Test error handler that tracks errors for verification
#[derive(Clone)]
pub struct TestErrorHandler {
errors: Arc<Mutex<Vec<(Bytes, Bytes, String)>>>,
}
impl TestErrorHandler {
pub fn new() -> Self {
Self {
errors: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn get_errors(&self) -> Vec<(Bytes, Bytes, String)> {
self.errors.lock().unwrap().clone()
}
pub fn error_count(&self) -> usize {
self.errors.lock().unwrap().len()
}
pub fn clear(&self) {
self.errors.lock().unwrap().clear();
}
}
impl TransportErrorHandler for TestErrorHandler {
fn on_error(&self, header: Bytes, payload: Bytes, error: String) {
self.errors.lock().unwrap().push((header, payload, error));
}
}
/// Handle to a transport instance with its streams for testing
///
/// This is a generic test handle that works with any transport implementation.
/// Use `TestTransportHandle::with_factory()` to create instances with custom transports,
/// or use convenience methods like `TestTransportHandle::new()` for TCP transport.
pub struct TestTransportHandle<T: Transport> {
pub transport: T,
pub streams: DataStreams,
pub instance_id: InstanceId,
pub error_handler: Arc<TestErrorHandler>,
runtime: tokio::runtime::Handle,
}
impl<T: Transport> TestTransportHandle<T> {
/// Create a new test transport using a factory function
///
/// This is the generic constructor that works with any transport implementation.
/// The factory function should create and return a transport instance.
///
/// # Example
/// ```ignore
/// let handle = TestTransportHandle::with_factory(|| {
/// MyTransportBuilder::new().build()
/// }).await?;
/// ```
pub async fn with_factory<F>(factory: F) -> anyhow::Result<Self>
where
F: FnOnce() -> anyhow::Result<T>,
{
let transport = factory()?;
let instance_id = InstanceId::new_v4();
let error_handler = Arc::new(TestErrorHandler::new());
// Create channels for this transport
let (adapter, streams) = velo_transports::make_channels();
// Get runtime handle
let runtime = tokio::runtime::Handle::current();
// Start the transport
transport
.start(instance_id, adapter, runtime.clone())
.await?;
// Give the listener a moment to bind and start accepting connections
tokio::time::sleep(Duration::from_millis(50)).await;
Ok(Self {
transport,
streams,
instance_id,
error_handler,
runtime,
})
}
/// Register another transport as a peer
pub fn register_peer<U: Transport>(
&self,
other: &TestTransportHandle<U>,
) -> anyhow::Result<()> {
let peer_info = PeerInfo::new(other.instance_id, other.transport.address());
self.transport
.register(peer_info)
.map_err(|e| anyhow::anyhow!("Failed to register peer: {:?}", e))?;
Ok(())
}
/// Send a message to a peer
pub fn send(
&self,
target: InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
msg_type: MessageType,
) {
self.transport.send_message(
target,
header,
payload,
msg_type,
self.error_handler.clone(),
);
}
/// Receive a message with timeout
pub async fn recv_message(&self, timeout_duration: Duration) -> anyhow::Result<(Bytes, Bytes)> {
timeout(timeout_duration, self.streams.message_stream.recv_async())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for message"))?
.map_err(|e| anyhow::anyhow!("Channel error: {}", e))
}
/// Receive a response with timeout
pub async fn recv_response(
&self,
timeout_duration: Duration,
) -> anyhow::Result<(Bytes, Bytes)> {
timeout(timeout_duration, self.streams.response_stream.recv_async())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for response"))?
.map_err(|e| anyhow::anyhow!("Channel error: {}", e))
}
/// Receive an event with timeout
pub async fn recv_event(&self, timeout_duration: Duration) -> anyhow::Result<(Bytes, Bytes)> {
timeout(timeout_duration, self.streams.event_stream.recv_async())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for event"))?
.map_err(|e| anyhow::anyhow!("Channel error: {}", e))
}
/// Collect multiple messages with timeout
pub async fn collect_messages(
&self,
count: usize,
timeout_duration: Duration,
) -> anyhow::Result<Vec<(Bytes, Bytes)>> {
let mut messages = Vec::new();
for _ in 0..count {
messages.push(self.recv_message(timeout_duration).await?);
}
Ok(messages)
}
/// Collect multiple messages with timeout, sorted by header for order-independent comparison
///
/// This is useful for testing transports that don't guarantee delivery order (e.g., HTTP).
/// Messages are sorted by header bytes to enable deterministic comparison regardless of
/// delivery order.
pub async fn collect_messages_unordered(
&self,
count: usize,
timeout_duration: Duration,
) -> anyhow::Result<Vec<(Bytes, Bytes)>> {
let mut messages = self.collect_messages(count, timeout_duration).await?;
messages.sort_by(|a, b| a.0.cmp(&b.0));
Ok(messages)
}
/// Collect multiple responses with timeout
pub async fn collect_responses(
&self,
count: usize,
timeout_duration: Duration,
) -> anyhow::Result<Vec<(Bytes, Bytes)>> {
let mut responses = Vec::new();
for _ in 0..count {
responses.push(self.recv_response(timeout_duration).await?);
}
Ok(responses)
}
/// Shutdown the transport
pub fn shutdown(self) {
self.transport.shutdown();
}
}
// TCP-specific convenience constructors
impl TestTransportHandle<TcpTransport> {
/// Create a new TCP transport on a random available port
///
/// This is a convenience method for creating TCP transports.
/// For other transport types, use `with_factory()`.
pub async fn new_tcp() -> anyhow::Result<Self> {
Self::with_factory(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
TcpTransportBuilder::new().from_listener(listener)?.build()
})
.await
}
/// Alias for `new_tcp()` to maintain backward compatibility
pub async fn new() -> anyhow::Result<Self> {
Self::new_tcp().await
}
}
// UDS-specific convenience constructors
#[cfg(unix)]
impl TestTransportHandle<UdsTransport> {
/// Create a new UDS transport using a temp directory socket path
pub async fn new_uds() -> anyhow::Result<Self> {
Self::with_factory(|| {
let dir = std::env::temp_dir().join(format!(
"velo-uds-test-{}",
velo_transports::InstanceId::new_v4()
));
std::fs::create_dir_all(&dir)?;
let socket_path = dir.join("transport.sock");
UdsTransportBuilder::new().socket_path(&socket_path).build()
})
.await
}
}
// // UCX-specific convenience constructors
// #[cfg(feature = "ucx")]
// impl TestTransportHandle<UcxTransport> {
// /// Create a new UCX transport
// ///
// /// This is a convenience method for creating UCX transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_ucx() -> anyhow::Result<Self> {
// Self::with_factory(|| UcxTransportBuilder::new().build()).await
// }
// }
// // HTTP-specific convenience constructors
// #[cfg(feature = "http")]
// impl TestTransportHandle<HttpTransport> {
// /// Create a new HTTP transport with OS-provided port
// ///
// /// This is a convenience method for creating HTTP transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_http() -> anyhow::Result<Self> {
// Self::with_factory(|| {
// // Use default builder which binds to 0.0.0.0:0 (OS-provided port)
// HttpTransportBuilder::new().build()
// })
// .await
// }
// }
// // NATS-specific convenience constructor
// #[cfg(feature = "nats")]
// impl TestTransportHandle<NatsTransport> {
// /// Create a new NATS transport
// ///
// /// This is a convenience method for creating NATS transports.
// /// For other transport types, use `with_factory()`.
// ///
// /// Note: NATS transport requires special handling because it needs the instance_id
// /// at construction time to set up subject subscriptions. We can't use the generic
// /// with_factory() because it creates the instance_id AFTER calling the factory.
// pub async fn new_nats() -> anyhow::Result<Self> {
// // Create instance_id
// let instance_id = InstanceId::new_v4();
// let error_handler = Arc::new(TestErrorHandler::new());
// // Build transport
// let transport = NatsTransportBuilder::new()
// .nats_url("nats://127.0.0.1:4222")
// .build()?;
// // Create channels for this transport
// let (adapter, streams) = velo_transports::make_channels();
// // Get runtime handle
// let runtime = tokio::runtime::Handle::current();
// // Start the transport
// transport
// .start(instance_id, adapter, runtime.clone())
// .await?;
// // Give NATS a moment to establish subscriptions
// tokio::time::sleep(Duration::from_millis(50)).await;
// Ok(Self {
// transport,
// streams,
// instance_id,
// error_handler,
// runtime,
// })
// }
// }
// // gRPC-specific convenience constructors
// #[cfg(feature = "grpc")]
// impl TestTransportHandle<GrpcTransport> {
// /// Create a new gRPC transport with OS-provided port
// ///
// /// This is a convenience method for creating gRPC transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_grpc() -> anyhow::Result<Self> {
// Self::with_factory(|| {
// // Use default builder which binds to 0.0.0.0:0 (OS-provided port)
// GrpcTransportBuilder::new().build()
// })
// .await
// }
// }
/// Multi-transport test cluster
///
/// A generic cluster that works with any transport implementation.
/// All transports in the cluster are registered with each other in a full mesh topology.
pub struct TestCluster<T: Transport> {
transports: Vec<TestTransportHandle<T>>,
}
impl<T: Transport> TestCluster<T> {
/// Create a new test cluster using a factory function
///
/// This is the generic constructor that works with any transport implementation.
/// The factory function will be called `size` times to create each transport.
///
/// # Example
/// ```ignore
/// let cluster = TestCluster::with_factory(3, || {
/// MyTransportBuilder::new().build()
/// }).await?;
/// ```
pub async fn with_factory<F>(size: usize, factory: F) -> anyhow::Result<Self>
where
F: Fn() -> anyhow::Result<T>,
{
let mut transports = Vec::new();
for _ in 0..size {
transports.push(TestTransportHandle::with_factory(&factory).await?);
}
// Register all peers with each other (full mesh)
for i in 0..transports.len() {
for j in 0..transports.len() {
if i != j {
transports[i].register_peer(&transports[j])?;
}
}
}
Ok(Self { transports })
}
/// Get a transport by index
pub fn get(&self, index: usize) -> &TestTransportHandle<T> {
&self.transports[index]
}
/// Get all transports
pub fn all(&self) -> &[TestTransportHandle<T>] {
&self.transports
}
/// Shutdown all transports
pub fn shutdown(self) {
for transport in self.transports {
transport.shutdown();
}
}
}
// TCP-specific convenience constructor
impl TestCluster<TcpTransport> {
/// Create a new TCP test cluster with the specified number of transports
///
/// This is a convenience method for creating TCP clusters.
/// For other transport types, use `with_factory()`.
pub async fn new(size: usize) -> anyhow::Result<Self> {
Self::with_factory(size, || {
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
TcpTransportBuilder::new().from_listener(listener)?.build()
})
.await
}
}
// UDS-specific convenience constructor
#[cfg(unix)]
impl TestCluster<UdsTransport> {
/// Create a new UDS test cluster with the specified number of transports
pub async fn new_uds(size: usize) -> anyhow::Result<Self> {
Self::with_factory(size, || {
let dir = std::env::temp_dir().join(format!(
"velo-uds-test-{}",
velo_transports::InstanceId::new_v4()
));
std::fs::create_dir_all(&dir)?;
let socket_path = dir.join("transport.sock");
UdsTransportBuilder::new().socket_path(&socket_path).build()
})
.await
}
}
// // HTTP-specific convenience constructor
// #[cfg(feature = "http")]
// impl TestCluster<HttpTransport> {
// /// Create a new HTTP test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating HTTP clusters.
// /// For other transport types, use `with_factory()`.
// pub async fn new_http(size: usize) -> anyhow::Result<Self> {
// Self::with_factory(size, || {
// // Use default builder which binds to OS-provided ports
// HttpTransportBuilder::new().build()
// })
// .await
// }
// }
// // NATS-specific convenience constructor
// #[cfg(feature = "nats")]
// impl TestCluster<NatsTransport> {
// /// Create a new NATS test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating NATS clusters.
// /// For other transport types, use `with_factory()`.
// ///
// /// Note: NATS transport requires special handling because it needs the instance_id
// /// at construction time. We can't use the generic with_factory() which creates
// /// instance_id after calling the factory function.
// pub async fn new_nats(size: usize) -> anyhow::Result<Self> {
// let mut transports = Vec::new();
// for _ in 0..size {
// transports.push(TestTransportHandle::new_nats().await?);
// }
// // Register all peers with each other (full mesh)
// for i in 0..transports.len() {
// for j in 0..transports.len() {
// if i != j {
// transports[i].register_peer(&transports[j])?;
// }
// }
// }
// Ok(Self { transports })
// }
// }
// // gRPC-specific convenience constructor
// #[cfg(feature = "grpc")]
// impl TestCluster<GrpcTransport> {
// /// Create a new gRPC test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating gRPC clusters.
// /// For other transport types, use `with_factory()`.
// pub async fn new_grpc(size: usize) -> anyhow::Result<Self> {
// Self::with_factory(size, || {
// // Use default builder which binds to OS-provided ports
// GrpcTransportBuilder::new().build()
// })
// .await
// }
// }
// Helper utilities
/// Get a random available port
pub fn get_random_port() -> u16 {
use std::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
listener.local_addr().unwrap().port()
}
/// Create test data with the specified size
pub fn test_data(size: usize) -> Vec<u8> {
(0..size).map(|i| (i % 256) as u8).collect()
}
/// Create a test message with predictable content
pub fn test_message(id: u32) -> (Vec<u8>, Vec<u8>) {
let header = format!("header-{}", id).into_bytes();
let payload = format!("payload-{}", id).into_bytes();
(header, payload)
}
/// Assert that a received message matches expected values
pub fn assert_message_eq(
received: (Bytes, Bytes),
expected_header: &[u8],
expected_payload: &[u8],
) {
assert_eq!(received.0.as_ref(), expected_header, "Header mismatch");
assert_eq!(received.1.as_ref(), expected_payload, "Payload mismatch");
}
// Transport factory abstraction for parameterized tests
/// Transport factory trait for creating transports in parameterized tests
pub trait TransportFactory {
type Transport: Transport;
async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>>;
async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>>;
}
/// TCP transport factory
pub struct TcpFactory;
impl TransportFactory for TcpFactory {
type Transport = TcpTransport;
async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
TestTransportHandle::new_tcp().await
}
async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
TestCluster::new(size).await
}
}
/// UDS transport factory
#[cfg(unix)]
pub struct UdsFactory;
#[cfg(unix)]
impl TransportFactory for UdsFactory {
type Transport = UdsTransport;
async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
TestTransportHandle::new_uds().await
}
async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
TestCluster::new_uds(size).await
}
}
// /// UCX transport factory
// #[cfg(feature = "ucx")]
// pub struct UcxFactory;
// #[cfg(feature = "ucx")]
// impl TransportFactory for UcxFactory {
// type Transport = UcxTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_ucx().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_ucx(size).await
// }
// }
// /// HTTP transport factory
// #[cfg(feature = "http")]
// pub struct HttpFactory;
// #[cfg(feature = "http")]
// impl TransportFactory for HttpFactory {
// type Transport = HttpTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_http().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_http(size).await
// }
// }
// /// NATS transport factory
// #[cfg(feature = "nats")]
// pub struct NatsFactory;
// #[cfg(feature = "nats")]
// impl TransportFactory for NatsFactory {
// type Transport = NatsTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_nats().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_nats(size).await
// }
// }
// /// gRPC transport factory
// #[cfg(feature = "grpc")]
// pub struct GrpcFactory;
// #[cfg(feature = "grpc")]
// impl TransportFactory for GrpcFactory {
// type Transport = GrpcTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_grpc().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_grpc(size).await
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Generic test scenarios that work with any transport implementation
use super::*;
use std::time::Duration;
const TEST_TIMEOUT: Duration = Duration::from_secs(5);
pub async fn single_message_round_trip<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Message,
);
let received = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn bidirectional_messaging<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
transport_b.register_peer(&transport_a).unwrap();
// A -> B
let (header1, payload1) = test_message(1);
transport_a.send(
transport_b.instance_id,
header1.clone(),
payload1.clone(),
MessageType::Message,
);
// B -> A
let (header2, payload2) = test_message(2);
transport_b.send(
transport_a.instance_id,
header2.clone(),
payload2.clone(),
MessageType::Message,
);
let recv_b = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
let recv_a = transport_a.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(recv_b, &header1, &payload1);
assert_message_eq(recv_a, &header2, &payload2);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn multiple_messages_same_connection<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Send 10 messages
for i in 0..10 {
let (header, payload) = test_message(i);
transport_a.send(
transport_b.instance_id,
header,
payload,
MessageType::Message,
);
}
// Receive and verify all messages (order-independent)
let messages = transport_b
.collect_messages_unordered(10, TEST_TIMEOUT)
.await
.unwrap();
// Generate expected messages and sort them the same way
let mut expected: Vec<_> = (0..10).map(test_message).collect();
expected.sort_by(|a, b| a.0.cmp(&b.0));
for (i, msg) in messages.iter().enumerate() {
assert_message_eq(msg.clone(), &expected[i].0, &expected[i].1);
}
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn response_message_type<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Response,
);
let received = transport_b.recv_response(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn event_message_type<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Event,
);
let received = transport_b.recv_event(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn ack_message_type<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Ack,
);
// Acks route to event stream
let received = transport_b.recv_event(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn mixed_message_types<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Send different message types
let (msg_h, msg_p) = test_message(1);
transport_a.send(
transport_b.instance_id,
msg_h.clone(),
msg_p.clone(),
MessageType::Message,
);
let (resp_h, resp_p) = test_message(2);
transport_a.send(
transport_b.instance_id,
resp_h.clone(),
resp_p.clone(),
MessageType::Response,
);
let (event_h, event_p) = test_message(3);
transport_a.send(
transport_b.instance_id,
event_h.clone(),
event_p.clone(),
MessageType::Event,
);
// Receive from appropriate streams
let recv_msg = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
let recv_resp = transport_b.recv_response(TEST_TIMEOUT).await.unwrap();
let recv_event = transport_b.recv_event(TEST_TIMEOUT).await.unwrap();
assert_message_eq(recv_msg, &msg_h, &msg_p);
assert_message_eq(recv_resp, &resp_h, &resp_p);
assert_message_eq(recv_event, &event_h, &event_p);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn large_payload<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// 1MB payload
let header = b"large-payload".to_vec();
let payload = test_data(1024 * 1024);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Message,
);
let received = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn empty_header_and_payload<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
transport_a.send(
transport_b.instance_id,
vec![],
vec![],
MessageType::Message,
);
let received = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &[], &[]);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn cluster_mesh_communication<F: TransportFactory>() {
let cluster = F::create_cluster(3).await.unwrap();
// Each node sends to every other node
for i in 0..3 {
for j in 0..3 {
if i != j {
let (header, payload) = test_message((i * 10 + j) as u32);
cluster.get(i).send(
cluster.get(j).instance_id,
header,
payload,
MessageType::Message,
);
}
}
}
// Each node should receive 2 messages
for i in 0..3 {
let messages = cluster
.get(i)
.collect_messages(2, TEST_TIMEOUT)
.await
.unwrap();
assert_eq!(messages.len(), 2);
}
cluster.shutdown();
}
pub async fn concurrent_senders<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Send from multiple tasks concurrently (without needing to move transport_a)
let target_id = transport_b.instance_id;
let mut handles = vec![];
for i in 0..10 {
let (header, payload) = test_message(i);
// Send directly without spawning - the send itself is non-blocking
transport_a.send(target_id, header, payload, MessageType::Message);
}
// Alternatively test with actual concurrent tasks using a different approach
// Spawn receiver tasks to demonstrate concurrent receives
for _ in 0..10 {
let handle = tokio::spawn(async {
// Just to demonstrate concurrency is working
tokio::time::sleep(Duration::from_micros(1)).await;
});
handles.push(handle);
}
// Wait for all tasks
for handle in handles {
handle.await.unwrap();
}
// Receive all messages
let messages = transport_b
.collect_messages(10, TEST_TIMEOUT)
.await
.unwrap();
assert_eq!(messages.len(), 10);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn send_to_unregistered_peer<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
// Don't register B with A
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Message,
);
// Give it a moment to process
tokio::time::sleep(Duration::from_millis(100)).await;
// Should have an error
assert!(transport_a.error_handler.error_count() > 0);
let errors = transport_a.error_handler.get_errors();
assert_eq!(errors.len(), 1);
assert_eq!(errors[0].0, header.as_slice());
assert_eq!(errors[0].1, payload.as_slice());
assert!(errors[0].2.to_lowercase().contains("peer not registered"));
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn connection_reuse<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// First message establishes connection
let (header1, payload1) = test_message(1);
transport_a.send(
transport_b.instance_id,
header1.clone(),
payload1.clone(),
MessageType::Message,
);
let recv1 = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(recv1, &header1, &payload1);
// Second message reuses connection
let (header2, payload2) = test_message(2);
transport_a.send(
transport_b.instance_id,
header2.clone(),
payload2.clone(),
MessageType::Message,
);
let recv2 = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(recv2, &header2, &payload2);
// No errors should have occurred
assert_eq!(transport_a.error_handler.error_count(), 0);
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn graceful_shutdown<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Send a message
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Message,
);
// Receive it
let received = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
// Shutdown should complete without panics
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn high_throughput<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
let num_messages = 100;
// Send many messages
for i in 0..num_messages {
let (header, payload) = test_message(i);
transport_a.send(
transport_b.instance_id,
header,
payload,
MessageType::Message,
);
}
// Receive all messages (order-independent)
let messages = transport_b
.collect_messages_unordered(num_messages as usize, TEST_TIMEOUT)
.await
.unwrap();
assert_eq!(messages.len(), num_messages as usize);
// Generate expected messages and sort them the same way
let mut expected: Vec<_> = (0..num_messages).map(test_message).collect();
expected.sort_by(|a, b| a.0.cmp(&b.0));
// Verify all messages received correctly
for (i, msg) in messages.iter().enumerate() {
assert_message_eq(msg.clone(), &expected[i].0, &expected[i].1);
}
transport_a.shutdown();
transport_b.shutdown();
}
pub async fn zero_copy_efficiency<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Large payload to test zero-copy
let header = b"zero-copy-test".to_vec();
let payload = test_data(512 * 1024); // 512KB
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Message,
);
let received = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
// Verify no errors
assert_eq!(transport_a.error_handler.error_count(), 0);
transport_a.shutdown();
transport_b.shutdown();
}
// --- Drain / shutdown scenarios ---
/// After begin_drain on B, messages sent from A to B should NOT arrive on B's message_stream.
pub async fn drain_rejects_messages<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Begin drain on B (both transport-level and shutdown-state, mirroring VeloBackend::graceful_shutdown)
transport_b.transport.begin_drain();
transport_b.streams.shutdown_state.begin_drain();
tokio::time::sleep(Duration::from_millis(100)).await;
// A sends a Message to B
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header,
payload,
MessageType::Message,
);
// B's message_stream should be empty (message rejected during drain)
let result = tokio::time::timeout(
Duration::from_millis(500),
transport_b.streams.message_stream.recv_async(),
)
.await;
assert!(
result.is_err(),
"Expected timeout — messages should be rejected during drain"
);
transport_a.transport.shutdown();
transport_b.streams.shutdown_state.teardown_token().cancel();
transport_b.transport.shutdown();
}
/// After begin_drain on B, responses sent from A to B should still arrive on B's response_stream.
pub async fn drain_accepts_responses<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Begin drain on B
transport_b.transport.begin_drain();
transport_b.streams.shutdown_state.begin_drain();
tokio::time::sleep(Duration::from_millis(100)).await;
// A sends a Response to B
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Response,
);
// B's response_stream should still receive it
let received = transport_b.recv_response(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.transport.shutdown();
transport_b.streams.shutdown_state.teardown_token().cancel();
transport_b.transport.shutdown();
}
/// After begin_drain on B, events sent from A to B should still arrive on B's event_stream.
pub async fn drain_accepts_events<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Begin drain on B
transport_b.transport.begin_drain();
transport_b.streams.shutdown_state.begin_drain();
tokio::time::sleep(Duration::from_millis(100)).await;
// A sends an Event to B
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Event,
);
// B's event_stream should still receive it
let received = transport_b.recv_event(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
transport_a.transport.shutdown();
transport_b.streams.shutdown_state.teardown_token().cancel();
transport_b.transport.shutdown();
}
/// After begin_drain on B, health checks from A to B should still succeed.
pub async fn health_during_drain<F: TransportFactory>() {
let transport_a = F::create().await.unwrap();
let transport_b = F::create().await.unwrap();
transport_a.register_peer(&transport_b).unwrap();
// Establish a connection first: send a message and receive it
let (header, payload) = test_message(1);
transport_a.send(
transport_b.instance_id,
header.clone(),
payload.clone(),
MessageType::Message,
);
let received = transport_b.recv_message(TEST_TIMEOUT).await.unwrap();
assert_message_eq(received, &header, &payload);
// Begin drain on B
transport_b.transport.begin_drain();
transport_b.streams.shutdown_state.begin_drain();
tokio::time::sleep(Duration::from_millis(100)).await;
// A checks health of B — should still succeed during drain
let result = transport_a
.transport
.check_health(transport_b.instance_id, Duration::from_secs(2))
.await;
assert!(
result.is_ok(),
"Health check should succeed during drain: {:?}",
result.err()
);
transport_a.transport.shutdown();
transport_b.streams.shutdown_state.teardown_token().cancel();
transport_b.transport.shutdown();
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for TCP transport
mod common;
use common::{TcpFactory, scenarios};
#[tokio::test]
async fn test_single_message_round_trip() {
scenarios::single_message_round_trip::<TcpFactory>().await;
}
#[tokio::test]
async fn test_bidirectional_messaging() {
scenarios::bidirectional_messaging::<TcpFactory>().await;
}
#[tokio::test]
async fn test_multiple_messages_same_connection() {
scenarios::multiple_messages_same_connection::<TcpFactory>().await;
}
#[tokio::test]
async fn test_response_message_type() {
scenarios::response_message_type::<TcpFactory>().await;
}
#[tokio::test]
async fn test_event_message_type() {
scenarios::event_message_type::<TcpFactory>().await;
}
#[tokio::test]
async fn test_ack_message_type() {
scenarios::ack_message_type::<TcpFactory>().await;
}
#[tokio::test]
async fn test_mixed_message_types() {
scenarios::mixed_message_types::<TcpFactory>().await;
}
#[tokio::test]
async fn test_large_payload() {
scenarios::large_payload::<TcpFactory>().await;
}
#[tokio::test]
async fn test_empty_header_and_payload() {
scenarios::empty_header_and_payload::<TcpFactory>().await;
}
#[tokio::test]
async fn test_cluster_mesh_communication() {
scenarios::cluster_mesh_communication::<TcpFactory>().await;
}
#[tokio::test]
async fn test_concurrent_senders() {
scenarios::concurrent_senders::<TcpFactory>().await;
}
#[tokio::test]
async fn test_send_to_unregistered_peer() {
scenarios::send_to_unregistered_peer::<TcpFactory>().await;
}
#[tokio::test]
async fn test_connection_reuse() {
scenarios::connection_reuse::<TcpFactory>().await;
}
#[tokio::test]
async fn test_graceful_shutdown() {
scenarios::graceful_shutdown::<TcpFactory>().await;
}
#[tokio::test]
async fn test_high_throughput() {
scenarios::high_throughput::<TcpFactory>().await;
}
#[tokio::test]
async fn test_zero_copy_efficiency() {
scenarios::zero_copy_efficiency::<TcpFactory>().await;
}
#[tokio::test]
async fn test_drain_rejects_messages() {
scenarios::drain_rejects_messages::<TcpFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_responses() {
scenarios::drain_accepts_responses::<TcpFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_events() {
scenarios::drain_accepts_events::<TcpFactory>().await;
}
#[tokio::test]
async fn test_health_during_drain() {
scenarios::health_during_drain::<TcpFactory>().await;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for TCP graceful shutdown
//!
//! These tests verify the 3-phase shutdown behavior:
//! 1. Gate: new Message frames are rejected with ShuttingDown
//! 2. Drain: in-flight work completes, responses/events still flow
//! 3. Teardown: listener and writer tasks exit
mod common;
use bytes::Bytes;
use std::time::Duration;
use tokio::time::{sleep, timeout};
use velo_transports::tcp::TcpFrameCodec;
use velo_transports::{MessageType, Transport};
use common::TestTransportHandle;
/// Helper: connect a raw TCP client to the transport's bind address and send a frame.
async fn connect_and_send_frame(
addr: std::net::SocketAddr,
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> tokio::net::TcpStream {
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
TcpFrameCodec::encode_frame(&mut stream, msg_type, header, payload)
.await
.unwrap();
stream
}
/// Helper: read one frame from a raw TCP stream.
async fn read_one_frame(stream: &mut tokio::net::TcpStream) -> (MessageType, Bytes, Bytes) {
use futures::StreamExt;
use tokio_util::codec::Framed;
let mut framed = Framed::new(stream, TcpFrameCodec::new());
framed.next().await.unwrap().unwrap()
}
/// Get the bind address from a TcpTransport by parsing its WorkerAddress.
fn get_bind_addr(
handle: &TestTransportHandle<velo_transports::tcp::TcpTransport>,
) -> std::net::SocketAddr {
let addr = handle.transport.address();
let key = handle.transport.key();
let endpoint = addr.get_entry(&key).unwrap().unwrap();
let s = std::str::from_utf8(&endpoint).unwrap();
let s = s.strip_prefix("tcp://").unwrap_or(s);
s.parse().unwrap()
}
// --- Test 18: Drain rejects Message frames ---
#[tokio::test]
async fn test_tcp_drain_rejects_messages() {
let handle = TestTransportHandle::new_tcp().await.unwrap();
let addr = get_bind_addr(&handle);
// Begin drain
handle.streams.shutdown_state.begin_drain();
// Give listener time to be ready
sleep(Duration::from_millis(50)).await;
// Connect and send a Message frame
let mut stream =
connect_and_send_frame(addr, MessageType::Message, b"req-header", b"req-payload").await;
// Should get ShuttingDown back
let (msg_type, header, payload) = read_one_frame(&mut stream).await;
assert_eq!(msg_type, MessageType::ShuttingDown);
assert_eq!(&header[..], b"req-header"); // Original header echoed back
assert_eq!(payload.len(), 0); // Empty payload
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 19: Drain accepts Response frames ---
#[tokio::test]
async fn test_tcp_drain_accepts_responses() {
let handle = TestTransportHandle::new_tcp().await.unwrap();
let addr = get_bind_addr(&handle);
// Begin drain
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Connect and send a Response frame
connect_and_send_frame(addr, MessageType::Response, b"resp-header", b"resp-payload").await;
// Should arrive on the response stream
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"resp-header");
assert_eq!(&payload[..], b"resp-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 20: Drain accepts Event frames ---
#[tokio::test]
async fn test_tcp_drain_accepts_events() {
let handle = TestTransportHandle::new_tcp().await.unwrap();
let addr = get_bind_addr(&handle);
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
connect_and_send_frame(addr, MessageType::Event, b"evt-header", b"evt-payload").await;
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.event_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"evt-header");
assert_eq!(&payload[..], b"evt-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 21: New connection during drain still accepts responses ---
#[tokio::test]
async fn test_tcp_new_connection_during_drain() {
let handle = TestTransportHandle::new_tcp().await.unwrap();
let addr = get_bind_addr(&handle);
// Begin drain BEFORE connecting
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Establish a NEW connection after drain starts
connect_and_send_frame(addr, MessageType::Response, b"new-resp", b"new-payload").await;
// Should arrive on the response stream
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"new-resp");
assert_eq!(&payload[..], b"new-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 22: ShuttingDown frame roundtrip ---
#[test]
fn test_shutting_down_frame_roundtrip() {
use bytes::BytesMut;
use tokio_util::codec::Decoder;
let header = b"correlation-header";
let payload = b"";
// Encode ShuttingDown frame
let mut buf = Vec::new();
TcpFrameCodec::encode_frame_sync(&mut buf, MessageType::ShuttingDown, header, payload).unwrap();
// Decode it
let mut codec = TcpFrameCodec::new();
let mut bytes = BytesMut::from(&buf[..]);
let (msg_type, decoded_header, decoded_payload) = codec.decode(&mut bytes).unwrap().unwrap();
assert_eq!(msg_type, MessageType::ShuttingDown);
assert_eq!(&decoded_header[..], header);
assert_eq!(decoded_payload.len(), 0);
}
// --- Test 23: Full graceful shutdown lifecycle ---
#[tokio::test]
async fn test_tcp_graceful_shutdown_lifecycle() {
let handle = TestTransportHandle::new_tcp().await.unwrap();
let addr = get_bind_addr(&handle);
// Verify normal operation: send a message, receive it
connect_and_send_frame(addr, MessageType::Message, b"normal-msg", b"normal-pay").await;
let (header, _payload) = timeout(
Duration::from_secs(2),
handle.streams.message_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"normal-msg");
// Acquire an InFlightGuard (simulate in-progress request)
let guard = handle.streams.shutdown_state.acquire();
assert_eq!(handle.streams.shutdown_state.in_flight_count(), 1);
// Begin drain (Phase 1)
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Verify new messages are rejected
let mut stream = connect_and_send_frame(addr, MessageType::Message, b"reject-me", b"").await;
let (msg_type, _, _) = read_one_frame(&mut stream).await;
assert_eq!(msg_type, MessageType::ShuttingDown);
// Verify responses still flow
connect_and_send_frame(addr, MessageType::Response, b"still-ok", b"data").await;
let (header, _) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"still-ok");
// Spawn graceful_shutdown in background (will block on drain since guard is held)
let shutdown_state = handle.streams.shutdown_state.clone();
let shutdown_handle = tokio::spawn(async move {
// Phase 2: wait for drain
shutdown_state.wait_for_drain().await;
// Phase 3: teardown
shutdown_state.teardown_token().cancel();
});
// Verify shutdown hasn't completed yet (guard still held)
sleep(Duration::from_millis(100)).await;
assert!(!shutdown_handle.is_finished());
// Drop guard → drain completes → teardown fires
drop(guard);
timeout(Duration::from_secs(2), shutdown_handle)
.await
.expect("shutdown should complete")
.unwrap();
assert!(
handle
.streams
.shutdown_state
.teardown_token()
.is_cancelled()
);
}
// --- Test 24: Shutdown timeout forces teardown ---
#[tokio::test]
async fn test_tcp_shutdown_timeout_forces_teardown() {
let handle = TestTransportHandle::new_tcp().await.unwrap();
// Acquire guard and hold it
let _guard = handle.streams.shutdown_state.acquire();
let shutdown_state = handle.streams.shutdown_state.clone();
let shutdown_handle = tokio::spawn(async move {
shutdown_state.begin_drain();
// Phase 2: wait with short timeout
let _ =
tokio::time::timeout(Duration::from_millis(100), shutdown_state.wait_for_drain()).await;
// Phase 3: teardown (forced, guard still held)
shutdown_state.teardown_token().cancel();
});
timeout(Duration::from_secs(2), shutdown_handle)
.await
.expect("shutdown should complete via timeout")
.unwrap();
// Teardown should have fired even though guard is held
assert!(
handle
.streams
.shutdown_state
.teardown_token()
.is_cancelled()
);
// Guard is still held (not a problem — teardown was forced)
assert_eq!(handle.streams.shutdown_state.in_flight_count(), 1);
}
// --- Test 25: Outbound sends during drain ---
#[tokio::test]
async fn test_outbound_sends_during_drain() {
// Create two transports and register them as peers
let handle_a = TestTransportHandle::new_tcp().await.unwrap();
let handle_b = TestTransportHandle::new_tcp().await.unwrap();
handle_a.register_peer(&handle_b).unwrap();
handle_b.register_peer(&handle_a).unwrap();
// Begin drain on transport A
handle_a.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Send a Response from A to B (outbound sends should work during drain)
handle_a.send(
handle_b.instance_id,
b"response-hdr".to_vec(),
b"response-pay".to_vec(),
MessageType::Response,
);
// B should receive the response
let (header, payload) = timeout(
Duration::from_secs(2),
handle_b.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"response-hdr");
assert_eq!(&payload[..], b"response-pay");
handle_a.streams.shutdown_state.teardown_token().cancel();
handle_b.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 26: Connection writer exits on teardown ---
#[tokio::test]
async fn test_connection_writer_exits_on_teardown() {
let handle_a = TestTransportHandle::new_tcp().await.unwrap();
let handle_b = TestTransportHandle::new_tcp().await.unwrap();
handle_a.register_peer(&handle_b).unwrap();
// Send a message to establish the connection writer task
handle_a.send(
handle_b.instance_id,
b"setup".to_vec(),
b"data".to_vec(),
MessageType::Message,
);
// Wait for it to arrive
timeout(
Duration::from_secs(2),
handle_b.streams.message_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
// Shutdown transport A
handle_a.transport.shutdown();
// Give writer tasks time to exit
sleep(Duration::from_millis(200)).await;
// Sending after shutdown: the cancel_token is already cancelled, so any new
// writer task returns immediately without connecting. The error handler must
// be invoked and the message must not arrive on handle_b.
handle_a.error_handler.clear();
handle_a.send(
handle_b.instance_id,
b"should-fail".to_vec(),
b"data".to_vec(),
MessageType::Message,
);
// Give time for the async error path to complete
sleep(Duration::from_millis(100)).await;
// Error handler must have been invoked for the failed send
assert!(
handle_a.error_handler.error_count() >= 1,
"error handler should be invoked for post-shutdown send"
);
// The message must not have been delivered to handle_b
let not_delivered = timeout(
Duration::from_millis(100),
handle_b.streams.message_stream.recv_async(),
)
.await;
assert!(
not_delivered.is_err(),
"post-shutdown message must not arrive at handle_b"
);
handle_b.streams.shutdown_state.teardown_token().cancel();
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for UDS transport
#![cfg(unix)]
mod common;
use common::{UdsFactory, scenarios};
#[tokio::test]
async fn test_single_message_round_trip() {
scenarios::single_message_round_trip::<UdsFactory>().await;
}
#[tokio::test]
async fn test_bidirectional_messaging() {
scenarios::bidirectional_messaging::<UdsFactory>().await;
}
#[tokio::test]
async fn test_multiple_messages_same_connection() {
scenarios::multiple_messages_same_connection::<UdsFactory>().await;
}
#[tokio::test]
async fn test_response_message_type() {
scenarios::response_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_event_message_type() {
scenarios::event_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_ack_message_type() {
scenarios::ack_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_mixed_message_types() {
scenarios::mixed_message_types::<UdsFactory>().await;
}
#[tokio::test]
async fn test_large_payload() {
scenarios::large_payload::<UdsFactory>().await;
}
#[tokio::test]
async fn test_empty_header_and_payload() {
scenarios::empty_header_and_payload::<UdsFactory>().await;
}
#[tokio::test]
async fn test_cluster_mesh_communication() {
scenarios::cluster_mesh_communication::<UdsFactory>().await;
}
#[tokio::test]
async fn test_concurrent_senders() {
scenarios::concurrent_senders::<UdsFactory>().await;
}
#[tokio::test]
async fn test_send_to_unregistered_peer() {
scenarios::send_to_unregistered_peer::<UdsFactory>().await;
}
#[tokio::test]
async fn test_connection_reuse() {
scenarios::connection_reuse::<UdsFactory>().await;
}
#[tokio::test]
async fn test_graceful_shutdown() {
scenarios::graceful_shutdown::<UdsFactory>().await;
}
#[tokio::test]
async fn test_high_throughput() {
scenarios::high_throughput::<UdsFactory>().await;
}
#[tokio::test]
async fn test_zero_copy_efficiency() {
scenarios::zero_copy_efficiency::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_rejects_messages() {
scenarios::drain_rejects_messages::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_responses() {
scenarios::drain_accepts_responses::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_events() {
scenarios::drain_accepts_events::<UdsFactory>().await;
}
#[tokio::test]
async fn test_health_during_drain() {
scenarios::health_during_drain::<UdsFactory>().await;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment