Unverified Commit 63d7c01c authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: velo-backend (#6547)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent 2cc92bfa
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP listener for ActiveMessage transport
//!
//! This module provides a TCP server that accepts incoming connections,
//! decodes framed messages using zero-copy techniques, and routes them
//! to the appropriate transport streams.
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::StreamExt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener as TokioTcpListener, TcpStream};
use tokio::runtime::{Handle, Runtime};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
use super::framing::TcpFrameCodec;
/// Runtime configuration for the TCP listener
pub enum RuntimeConfig {
/// Use an existing tokio runtime handle.
Handle(Handle),
/// Use a provided tokio runtime.
Runtime(Arc<Runtime>),
/// Create a single-threaded runtime pinned to the specified CPU core (Linux only).
CpuPin(usize),
}
/// High-performance TCP listener for ActiveMessage transport
///
/// This listener accepts incoming TCP connections and routes decoded frames
/// to the appropriate transport streams with zero-copy performance.
pub struct TcpListener {
bind_addr: SocketAddr,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
runtime_config: RuntimeConfig,
listener: Option<std::net::TcpListener>,
}
impl TcpListener {
/// Create a new builder for TcpListener
pub fn builder() -> TcpListenerBuilder {
TcpListenerBuilder::new()
}
/// Start the listener and serve incoming connections
///
/// This method blocks or spawns based on the runtime configuration:
/// - For Handle/Runtime: spawns tasks and returns immediately
/// - For CpuPin: creates a pinned runtime and blocks until cancellation
pub async fn serve(mut self) -> Result<()> {
// Extract runtime config to avoid borrow issues
let runtime_config = std::mem::replace(
&mut self.runtime_config,
RuntimeConfig::Handle(Handle::current()),
);
match runtime_config {
RuntimeConfig::Handle(handle) => {
handle.spawn(async move {
if let Err(e) = self.run_server().await {
error!("TCP listener error: {}", e);
}
});
Ok(())
}
RuntimeConfig::Runtime(rt) => {
rt.spawn(async move {
if let Err(e) = self.run_server().await {
error!("TCP listener error: {}", e);
}
});
Ok(())
}
RuntimeConfig::CpuPin(cpu_id) => {
let rt = Self::create_pinned_runtime(cpu_id)
.context("Failed to create CPU-pinned runtime")?;
rt.block_on(self.run_server())
}
}
}
/// Create a single-threaded runtime pinned to a specific CPU core
#[cfg(target_os = "linux")]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
use nix::sched::{CpuSet, sched_setaffinity};
use nix::unistd::Pid;
tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name("tcp-listener-pinned")
.on_thread_start(move || {
let mut cpu_set = CpuSet::new();
if cpu_set.set(cpu_id).is_ok() {
if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set) {
error!("Failed to pin thread to CPU {}: {}", cpu_id, e);
} else {
debug!("Successfully pinned TCP listener to CPU {}", cpu_id);
}
}
})
.build()
.context("Failed to build tokio runtime")
}
/// Create a single-threaded runtime without CPU pinning (non-Linux platforms)
#[cfg(not(target_os = "linux"))]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
warn!(
"CPU pinning requested (CPU {}) but not supported on this platform",
cpu_id
);
tokio::runtime::Builder::new_current_thread()
.enable_all()
.thread_name("tcp-listener")
.build()
.context("Failed to build tokio runtime")
}
/// Main server loop that accepts connections
async fn run_server(self) -> Result<()> {
// Use pre-bound listener if provided, otherwise bind to the address
let listener = if let Some(std_listener) = self.listener {
// Set non-blocking for tokio conversion
std_listener
.set_nonblocking(true)
.context("Failed to set listener to non-blocking")?;
TokioTcpListener::from_std(std_listener)
.context("Failed to convert std TcpListener to tokio TcpListener")?
} else {
TokioTcpListener::bind(self.bind_addr)
.await
.context(format!("Failed to bind TCP listener to {}", self.bind_addr))?
};
let local_addr = listener
.local_addr()
.context("Failed to get local address")?;
info!("TCP listener bound to {}", local_addr);
let teardown_token = self.shutdown_state.teardown_token().clone();
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, peer_addr)) => {
debug!("Accepted TCP connection from {}", peer_addr);
let adapter = self.adapter.clone();
let error_handler = self.error_handler.clone();
let shutdown_state = self.shutdown_state.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(
stream,
peer_addr,
adapter,
error_handler,
shutdown_state,
)
.await
{
warn!("Error handling connection from {}: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("Failed to accept TCP connection: {}", e);
}
}
}
_ = teardown_token.cancelled() => {
info!("TCP listener shutting down (teardown)");
break;
}
}
}
Ok(())
}
/// Handle a single TCP connection
async fn handle_connection(
stream: TcpStream,
peer_addr: SocketAddr,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
) -> Result<()> {
debug!("Configuring connection from {}", peer_addr);
// Configure socket for high performance
if let Err(e) = stream.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY on {}: {}", peer_addr, e);
}
#[allow(deprecated)] // Intentional: linger ensures clean socket shutdown
if let Err(e) = stream.set_linger(Some(Duration::from_secs(1))) {
warn!("Failed to set linger on {}: {}", peer_addr, e);
}
// Set keep-alive to detect dead connections
let keepalive = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10));
let sock_ref = socket2::SockRef::from(&stream);
if let Err(e) = sock_ref.set_tcp_keepalive(&keepalive) {
warn!("Failed to set TCP keepalive on {}: {}", peer_addr, e);
}
// Set large receive buffer for high throughput
if let Err(e) = sock_ref.set_recv_buffer_size(1_048_576) {
warn!("Failed to set receive buffer size on {}: {}", peer_addr, e);
}
// Create framed stream with zero-copy codec
let mut framed = Framed::new(stream, TcpFrameCodec::new());
let teardown_token = shutdown_state.teardown_token().clone();
debug!("Connection from {} ready for frames", peer_addr);
loop {
tokio::select! {
frame_result = framed.next() => {
match frame_result {
Some(Ok((msg_type, header, payload))) => {
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if shutdown_state.is_draining() && msg_type == MessageType::Message {
debug!(
"Rejecting Message frame from {} during drain (sending ShuttingDown)",
peer_addr
);
// Echo original header back for correlation, empty payload
if let Err(e) = TcpFrameCodec::encode_frame(
framed.get_mut(),
MessageType::ShuttingDown,
&header,
&[],
)
.await
{
warn!(
"Failed to send ShuttingDown frame to {}: {}",
peer_addr, e
);
}
continue;
}
// Route frame to appropriate stream based on type
if let Err(e) = Self::route_frame(
msg_type,
header,
payload,
&adapter,
&error_handler,
)
.await
{
warn!(
"Failed to route {:?} frame from {}: {}",
msg_type, peer_addr, e
);
}
}
Some(Err(e)) => {
error!("Frame decode error from {}: {}", peer_addr, e);
break;
}
None => {
// Connection closed gracefully (FIN received)
debug!("Connection from {} closed gracefully", peer_addr);
break;
}
}
}
_ = teardown_token.cancelled() => {
debug!("Connection handler for {} torn down", peer_addr);
break;
}
}
}
Ok(())
}
/// Route a decoded frame to the appropriate stream
///
/// This function performs zero-copy routing by transferring ownership of
/// the Bytes to the flume channel. On error, it invokes the error callback
/// with the original data (requiring a clone).
async fn route_frame(
msg_type: MessageType,
header: Bytes,
payload: Bytes,
adapter: &TransportAdapter,
error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
let sender = match msg_type {
MessageType::Message => &adapter.message_stream,
MessageType::Response => &adapter.response_stream,
MessageType::Ack | MessageType::Event => &adapter.event_stream,
MessageType::ShuttingDown => {
// ShuttingDown is an outbound-only frame type; receiving it here
// means a remote peer rejected our request. Route to the response
// stream so higher layers can handle the rejection via correlation.
&adapter.response_stream
}
};
// Try to send with ownership transfer (zero-copy)
match sender.send_async((header, payload)).await {
Ok(_) => Ok(()),
Err(e) => {
// Send failed - invoke error callback with the data
error_handler.on_error(
e.0.0, // header
e.0.1, // payload
format!("Failed to route {:?}", msg_type),
);
Err(anyhow::anyhow!("Failed to send to stream"))
}
}
}
}
/// Builder for TcpListener
pub struct TcpListenerBuilder {
bind_addr: Option<SocketAddr>,
adapter: Option<TransportAdapter>,
error_handler: Option<Arc<dyn TransportErrorHandler>>,
shutdown_state: Option<ShutdownState>,
runtime_config: Option<RuntimeConfig>,
listener: Option<std::net::TcpListener>,
}
impl TcpListenerBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
bind_addr: None,
adapter: None,
error_handler: None,
shutdown_state: None,
runtime_config: None,
listener: None,
}
}
/// Set the bind address
pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
self.bind_addr = Some(addr);
self
}
/// Set the transport adapter
pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
self.adapter = Some(adapter);
self
}
/// Set the error handler
pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
self.error_handler = Some(handler);
self
}
/// Set the shutdown state for graceful drain coordination
pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
self.shutdown_state = Some(state);
self
}
/// Use an existing tokio runtime handle
pub fn with_handle(mut self, handle: Handle) -> Self {
self.runtime_config = Some(RuntimeConfig::Handle(handle));
self
}
/// Use a provided tokio runtime
pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
self.runtime_config = Some(RuntimeConfig::Runtime(runtime));
self
}
/// Create a single-threaded runtime pinned to a specific CPU core
pub fn with_cpu_pin(mut self, cpu_id: usize) -> Self {
self.runtime_config = Some(RuntimeConfig::CpuPin(cpu_id));
self
}
/// Use a pre-bound TcpListener
///
/// This is useful for tests where you want to bind to port 0 and avoid port races.
/// When provided, the bind_addr should still be set (for logging/debugging purposes).
pub fn listener(mut self, listener: Option<std::net::TcpListener>) -> Self {
self.listener = listener;
self
}
/// Build the TcpListener
pub fn build(self) -> Result<TcpListener> {
let bind_addr = self
.bind_addr
.ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;
let adapter = self
.adapter
.ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
let error_handler = self
.error_handler
.ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
let shutdown_state = self.shutdown_state.unwrap_or_default();
let runtime_config = self
.runtime_config
.unwrap_or_else(|| RuntimeConfig::Handle(Handle::current()));
Ok(TcpListener {
bind_addr,
adapter,
error_handler,
shutdown_state,
runtime_config,
listener: self.listener,
})
}
}
impl Default for TcpListenerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::make_channels;
use std::net::{IpAddr, Ipv4Addr};
struct TestErrorHandler;
impl TransportErrorHandler for TestErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
eprintln!("Test error handler: {}", error);
}
}
#[test]
fn test_builder_requires_fields() {
let result = TcpListener::builder().build();
assert!(result.is_err());
}
#[tokio::test]
async fn test_builder_with_all_fields() {
let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = TcpListener::builder()
.bind_addr(bind_addr)
.adapter(adapter)
.error_handler(error_handler)
.build();
assert!(result.is_ok());
}
#[test]
fn test_builder_with_cpu_pin() {
let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = TcpListener::builder()
.bind_addr(bind_addr)
.adapter(adapter)
.error_handler(error_handler)
.with_cpu_pin(0)
.build();
assert!(result.is_ok());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TCP Transport Module
//!
//! This module provides a high-performance TCP transport implementation with:
//! - Zero-copy frame codec for minimal overhead
//! - CPU pinning support for predictable latency
//! - Frame type routing (Message, Response, Ack, Event)
//! - Graceful shutdown with proper FIN handling
//! - Keep-alive for dead connection detection
mod framing;
mod listener;
mod transport;
pub use framing::TcpFrameCodec;
pub use listener::{RuntimeConfig, TcpListener, TcpListenerBuilder};
pub use transport::{TcpTransport, TcpTransportBuilder};
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use bytes::Bytes;
use futures::future::BoxFuture;
use crate::{InstanceId, PeerInfo, TransportKey, WorkerAddress};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
/// Errors returned by individual [`Transport`] implementations.
#[derive(thiserror::Error, Debug)]
pub enum TransportError {
/// The peer's [`WorkerAddress`] does not contain an entry for this transport.
#[error("No endpoint found for transport")]
NoEndpoint,
/// The endpoint string could not be parsed (malformed URL, invalid address).
#[error("Invalid endpoint format")]
InvalidEndpoint,
/// The target peer was never registered with this transport.
#[error("Peer not registered: {0}")]
PeerNotRegistered(InstanceId),
/// The transport has not been started yet (no runtime handle).
#[error("Transport not started")]
NotStarted,
/// No responders available for the peer (e.g. NATS request with no subscriber).
#[error("No responders for peer")]
NoResponders,
}
/// Error type specific to health check operations
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckError {
/// The peer was never registered with this transport.
#[error("Peer not registered with transport")]
PeerNotRegistered,
/// The transport has not been started yet.
#[error("Transport not started")]
TransportNotStarted,
/// The peer is registered but no connection has ever been established.
#[error("Connection never established to peer")]
NeverConnected,
/// An existing connection is unhealthy or the peer is unreachable.
#[error("Connection failed or peer unreachable")]
ConnectionFailed,
/// The health check exceeded the specified timeout.
#[error("Health check timed out")]
Timeout,
}
/// Shared shutdown coordinator for graceful multi-phase shutdown.
///
/// **Phases**:
/// 1. **Gate** — `begin_drain()` flips the draining flag; transports reject new inbound requests.
/// 2. **Drain** — `wait_for_drain()` blocks until all in-flight guards are dropped.
/// 3. **Teardown** — `teardown_token().cancel()` kills listeners and writer tasks.
///
/// Hot-path cost: a single `AtomicBool::load(Relaxed)` per frame to check `is_draining()`.
#[derive(Clone)]
pub struct ShutdownState {
inner: Arc<ShutdownStateInner>,
}
struct ShutdownStateInner {
draining: AtomicBool,
in_flight: AtomicUsize,
drain_complete: Notify,
teardown_token: CancellationToken,
}
impl ShutdownState {
/// Create a new shutdown state. Not draining, zero in-flight.
pub fn new() -> Self {
Self {
inner: Arc::new(ShutdownStateInner {
draining: AtomicBool::new(false),
in_flight: AtomicUsize::new(0),
drain_complete: Notify::new(),
teardown_token: CancellationToken::new(),
}),
}
}
/// Returns `true` if drain has been initiated (Phase 1).
///
/// Uses `Relaxed` ordering — safe for the hot-path gate check because
/// the flag is monotonic (false → true, never reset).
#[inline]
pub fn is_draining(&self) -> bool {
self.inner.draining.load(Ordering::Relaxed)
}
/// Begin Phase 1: flip the draining flag. Idempotent.
pub fn begin_drain(&self) {
self.inner.draining.store(true, Ordering::Release);
}
/// Acquire an in-flight guard. The guard increments the counter on creation
/// and decrements it on drop. Use this to track requests that are being processed.
///
/// Guards are still acquirable after `begin_drain()` — this is intentional
/// so that already-accepted work can be tracked.
pub fn acquire(&self) -> InFlightGuard {
self.inner.in_flight.fetch_add(1, Ordering::AcqRel);
InFlightGuard {
inner: self.inner.clone(),
}
}
/// Current number of in-flight requests. Primarily for testing/debugging.
pub fn in_flight_count(&self) -> usize {
self.inner.in_flight.load(Ordering::Acquire)
}
/// Wait until in-flight count reaches zero. Returns immediately if already zero.
pub async fn wait_for_drain(&self) {
loop {
if self.inner.in_flight.load(Ordering::Acquire) == 0 {
return;
}
self.inner.drain_complete.notified().await;
}
}
/// Get the Phase 3 teardown token. Cancel this to kill listeners/writers.
pub fn teardown_token(&self) -> &CancellationToken {
&self.inner.teardown_token
}
}
impl Default for ShutdownState {
fn default() -> Self {
Self::new()
}
}
/// RAII guard that decrements the in-flight counter on drop.
pub struct InFlightGuard {
inner: Arc<ShutdownStateInner>,
}
impl InFlightGuard {
/// Explicitly complete this guard (equivalent to dropping it).
pub fn complete(self) {
// Drop impl handles the decrement
}
}
impl Drop for InFlightGuard {
fn drop(&mut self) {
let prev = self.inner.in_flight.fetch_sub(1, Ordering::AcqRel);
// If we just decremented to 0, notify waiters
if prev == 1 {
self.inner.drain_complete.notify_waiters();
}
}
}
/// Policy for how long to wait during the drain phase.
#[derive(Debug, Clone)]
pub enum ShutdownPolicy {
/// Wait indefinitely for all in-flight requests to complete.
WaitForever,
/// Wait up to the given duration, then force teardown.
Timeout(Duration),
}
/// Abstraction over a single message transport (TCP, HTTP, NATS, gRPC, UCX).
///
/// Implementations handle peer registration, message sending, listener lifecycle,
/// health checking, and graceful shutdown. The trait is object-safe so transports
/// can be stored as `Arc<dyn Transport>`.
pub trait Transport: Send + Sync {
/// Unique key identifying this transport (e.g. `"tcp"`, `"grpc"`).
fn key(&self) -> TransportKey;
/// The [`WorkerAddress`] fragment advertised by this transport.
fn address(&self) -> WorkerAddress;
/// Register a remote peer, extracting its endpoint from [`PeerInfo`].
fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError>;
/// Sends an active message to the remote instance
fn send_message(
&self,
instance_id: InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
);
/// Start the transport (bind listener, spawn tasks) for the given instance.
fn start(
&self,
instance_id: InstanceId,
channels: TransportAdapter,
rt: tokio::runtime::Handle,
) -> BoxFuture<'_, anyhow::Result<()>>;
/// Tear down the transport, cancelling all tasks and closing connections.
fn shutdown(&self);
/// Begin draining: reject new inbound requests while allowing responses.
///
/// Default implementation is a no-op. Transports that need per-frame
/// gating (e.g., unsubscribing from NATS subjects) should override this.
fn begin_drain(&self) {}
/// Check if a registered peer is reachable and healthy
///
/// Returns Ok(()) if peer responds to health check within timeout.
/// Different transports implement this differently:
/// - NATS: request/reply to health subject
/// - TCP: check existing connection or attempt new connection
/// - HTTP: HEAD request to health endpoint
/// - UCX: endpoint status check
///
/// # Errors
/// - `PeerNotRegistered`: Peer was never registered with this transport
/// - `TransportNotStarted`: Transport hasn't been started yet
/// - `NeverConnected`: Peer is registered but no connection has been established
/// - `ConnectionFailed`: Connection exists/existed but is currently unhealthy or unreachable
/// - `Timeout`: Health check took longer than the specified timeout
fn check_health(
&self,
instance_id: InstanceId,
timeout: Duration,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
>;
}
/// Callback trait invoked when a transport fails to deliver a message.
///
/// The original `header` and `payload` are returned so higher layers can
/// retry or log the failure.
pub trait TransportErrorHandler: Send + Sync {
/// Called when message delivery fails. Receives the original data and error description.
fn on_error(&self, header: Bytes, payload: Bytes, error: String);
}
/// Message type discriminator for routing frames to appropriate streams
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
#[allow(missing_docs)]
Message = 0,
#[allow(missing_docs)]
Response = 1,
#[allow(missing_docs)]
Ack = 2,
#[allow(missing_docs)]
Event = 3,
/// Sent back to a peer when we are draining and cannot accept new messages.
/// The original request header is echoed back for correlation.
ShuttingDown = 4,
}
impl MessageType {
/// Try to convert a u8 to a MessageType
pub fn from_u8(value: u8) -> Option<Self> {
match value {
0 => Some(MessageType::Message),
1 => Some(MessageType::Response),
2 => Some(MessageType::Ack),
3 => Some(MessageType::Event),
4 => Some(MessageType::ShuttingDown),
_ => None,
}
}
/// Convert MessageType to u8
pub fn as_u8(self) -> u8 {
self as u8
}
}
/// Sender-side handle given to transports for routing inbound frames.
///
/// Each transport receives a clone of this adapter during [`Transport::start`]
/// and uses it to forward decoded `(header, payload)` pairs to the appropriate
/// stream based on [`MessageType`].
#[derive(Clone)]
pub struct TransportAdapter {
/// Channel for inbound [`MessageType::Message`] frames.
pub message_stream: flume::Sender<(Bytes, Bytes)>,
/// Channel for inbound [`MessageType::Response`] and [`MessageType::ShuttingDown`] frames.
pub response_stream: flume::Sender<(Bytes, Bytes)>,
/// Channel for inbound [`MessageType::Ack`] and [`MessageType::Event`] frames.
pub event_stream: flume::Sender<(Bytes, Bytes)>,
/// Shared shutdown coordinator for drain-aware routing.
pub shutdown_state: ShutdownState,
}
/// Receiver-side handle for consuming inbound frames from all transports.
///
/// Returned by [`make_channels`] alongside the corresponding [`TransportAdapter`].
/// Higher layers pull `(header, payload)` pairs from these channels.
pub struct DataStreams {
/// Receiver for inbound message frames.
pub message_stream: flume::Receiver<(Bytes, Bytes)>,
/// Receiver for inbound response and shutting-down frames.
pub response_stream: flume::Receiver<(Bytes, Bytes)>,
/// Receiver for inbound ack and event frames.
pub event_stream: flume::Receiver<(Bytes, Bytes)>,
/// Shared shutdown coordinator.
pub shutdown_state: ShutdownState,
}
type DataStreamTuple = (
flume::Receiver<(Bytes, Bytes)>,
flume::Receiver<(Bytes, Bytes)>,
flume::Receiver<(Bytes, Bytes)>,
);
impl DataStreams {
/// Destructure into the three raw receivers `(message, response, event)`.
pub fn into_parts(self) -> DataStreamTuple {
(self.message_stream, self.response_stream, self.event_stream)
}
/// Receive a message with an in-flight guard for drain tracking.
///
/// Returns `(header, payload, guard)`. The guard keeps the in-flight counter
/// incremented until it is dropped or `complete()` is called.
pub async fn recv_message_tracked(
&self,
) -> Result<(Bytes, Bytes, InFlightGuard), flume::RecvError> {
let (header, payload) = self.message_stream.recv_async().await?;
let guard = self.shutdown_state.acquire();
Ok((header, payload, guard))
}
}
/// Create a matched pair of [`TransportAdapter`] (sender) and [`DataStreams`] (receiver).
///
/// Both sides share the same [`ShutdownState`] so drain coordination is automatic.
pub fn make_channels() -> (TransportAdapter, DataStreams) {
let shutdown_state = ShutdownState::new();
let (message_tx, message_rx) = flume::unbounded();
let (response_tx, response_rx) = flume::unbounded();
let (event_tx, event_rx) = flume::unbounded();
(
TransportAdapter {
message_stream: message_tx,
response_stream: response_tx,
event_stream: event_tx,
shutdown_state: shutdown_state.clone(),
},
DataStreams {
message_stream: message_rx,
response_stream: response_rx,
event_stream: event_rx,
shutdown_state,
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{sleep, timeout};
#[test]
fn test_shutdown_state_initial() {
let state = ShutdownState::new();
assert!(!state.is_draining());
assert_eq!(state.in_flight_count(), 0);
}
#[test]
fn test_begin_drain_flips_flag() {
let state = ShutdownState::new();
state.begin_drain();
assert!(state.is_draining());
}
#[test]
fn test_begin_drain_idempotent() {
let state = ShutdownState::new();
state.begin_drain();
state.begin_drain();
assert!(state.is_draining());
}
#[test]
fn test_acquire_increments_inflight() {
let state = ShutdownState::new();
let _g1 = state.acquire();
assert_eq!(state.in_flight_count(), 1);
let _g2 = state.acquire();
assert_eq!(state.in_flight_count(), 2);
}
#[test]
fn test_guard_drop_decrements_inflight() {
let state = ShutdownState::new();
let g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
drop(g);
assert_eq!(state.in_flight_count(), 0);
}
#[test]
fn test_guard_complete_decrements() {
let state = ShutdownState::new();
let g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
g.complete();
assert_eq!(state.in_flight_count(), 0);
}
#[tokio::test]
async fn test_wait_for_drain_immediate() {
let state = ShutdownState::new();
// Should complete immediately since in_flight is 0
timeout(Duration::from_millis(100), state.wait_for_drain())
.await
.expect("wait_for_drain should complete immediately when in_flight is 0");
}
#[tokio::test]
async fn test_wait_for_drain_blocks_then_completes() {
let state = ShutdownState::new();
let guard = state.acquire();
let state_clone = state.clone();
let handle = tokio::spawn(async move {
state_clone.wait_for_drain().await;
});
// Give the waiter time to park
sleep(Duration::from_millis(50)).await;
assert!(!handle.is_finished());
// Drop guard → should unblock
drop(guard);
timeout(Duration::from_millis(100), handle)
.await
.expect("should complete after guard drop")
.unwrap();
}
#[tokio::test]
async fn test_multiple_guards_concurrent() {
let state = ShutdownState::new();
let guards: Vec<_> = (0..10).map(|_| state.acquire()).collect();
assert_eq!(state.in_flight_count(), 10);
let state_clone = state.clone();
let handle = tokio::spawn(async move {
state_clone.wait_for_drain().await;
});
// Drop all guards
drop(guards);
timeout(Duration::from_millis(100), handle)
.await
.expect("should complete after all guards drop")
.unwrap();
assert_eq!(state.in_flight_count(), 0);
}
#[tokio::test]
async fn test_drain_with_zero_inflight() {
let state = ShutdownState::new();
state.begin_drain();
// Should complete immediately
timeout(Duration::from_millis(100), state.wait_for_drain())
.await
.expect("should complete immediately with zero in-flight");
}
#[test]
fn test_acquire_works_after_drain() {
let state = ShutdownState::new();
state.begin_drain();
let _g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
}
#[test]
fn test_guard_drop_during_panic() {
let state = ShutdownState::new();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _g = state.acquire();
assert_eq!(state.in_flight_count(), 1);
panic!("intentional panic");
}));
assert!(result.is_err());
// Guard's Drop should have fired even during unwind
assert_eq!(state.in_flight_count(), 0);
}
#[test]
fn test_shutting_down_from_u8() {
assert_eq!(MessageType::from_u8(4), Some(MessageType::ShuttingDown));
}
#[test]
fn test_shutting_down_as_u8() {
assert_eq!(MessageType::ShuttingDown.as_u8(), 4);
}
#[test]
fn test_unknown_message_type_still_none() {
assert_eq!(MessageType::from_u8(5), None);
assert_eq!(MessageType::from_u8(255), None);
}
#[test]
fn test_make_channels_includes_shutdown_state() {
let (adapter, streams) = make_channels();
// Both sides should share the same ShutdownState (via Arc)
assert!(!adapter.shutdown_state.is_draining());
assert!(!streams.shutdown_state.is_draining());
// Mutating one should be visible through the other
adapter.shutdown_state.begin_drain();
assert!(streams.shutdown_state.is_draining());
}
#[tokio::test]
async fn test_recv_message_tracked_returns_guard() {
let (adapter, streams) = make_channels();
// Send a message through the adapter
adapter
.message_stream
.send_async((
bytes::Bytes::from_static(b"hdr"),
bytes::Bytes::from_static(b"pay"),
))
.await
.unwrap();
// Receive with tracking
let (header, payload, guard) = streams.recv_message_tracked().await.unwrap();
assert_eq!(&header[..], b"hdr");
assert_eq!(&payload[..], b"pay");
assert_eq!(streams.shutdown_state.in_flight_count(), 1);
// Drop guard
drop(guard);
assert_eq!(streams.shutdown_state.in_flight_count(), 0);
}
#[test]
fn test_shutdown_state_clone_shares_inner() {
let s1 = ShutdownState::new();
let s2 = s1.clone();
s1.begin_drain();
assert!(s2.is_draining());
let _g = s1.acquire();
assert_eq!(s2.in_flight_count(), 1);
}
#[test]
fn test_teardown_token() {
let state = ShutdownState::new();
assert!(!state.teardown_token().is_cancelled());
state.teardown_token().cancel();
assert!(state.teardown_token().is_cancelled());
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment