Unverified Commit 9e9ca3e2 authored by Patrick's avatar Patrick Committed by GitHub
Browse files

feat(velo): add unix domain socket (#7197)


Signed-off-by: default avatarPatrick Riel <priel@nvidia.com>
Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent d5680ed9
...@@ -31,6 +31,9 @@ mod address; ...@@ -31,6 +31,9 @@ mod address;
pub mod tcp; pub mod tcp;
#[cfg(unix)]
pub mod uds;
// #[cfg(feature = "ucx")] // #[cfg(feature = "ucx")]
// pub mod ucx; // pub mod ucx;
......
...@@ -392,9 +392,9 @@ async fn connection_writer_task( ...@@ -392,9 +392,9 @@ async fn connection_writer_task(
instance_id: crate::InstanceId, instance_id: crate::InstanceId,
rx: flume::Receiver<SendTask>, rx: flume::Receiver<SendTask>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>, connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
_cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let result = connection_writer_inner(addr, instance_id, &rx).await; let result = connection_writer_inner(addr, instance_id, &rx, &cancel_token).await;
// Always drain queued messages and notify their error handlers. // Always drain queued messages and notify their error handlers.
// //
...@@ -426,10 +426,14 @@ async fn connection_writer_inner( ...@@ -426,10 +426,14 @@ async fn connection_writer_inner(
addr: SocketAddr, addr: SocketAddr,
instance_id: crate::InstanceId, instance_id: crate::InstanceId,
rx: &flume::Receiver<SendTask>, rx: &flume::Receiver<SendTask>,
cancel_token: &CancellationToken,
) -> Result<()> { ) -> Result<()> {
debug!("Connecting to {}", addr); debug!("Connecting to {}", addr);
let mut stream = TcpStream::connect(addr).await.context("connect failed")?; let mut stream = tokio::select! {
_ = cancel_token.cancelled() => return Ok(()),
res = TcpStream::connect(addr) => res.context("connect failed")?,
};
if let Err(e) = stream.set_nodelay(true) { if let Err(e) = stream.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY: {}", e); warn!("Failed to set TCP_NODELAY: {}", e);
...@@ -450,7 +454,14 @@ async fn connection_writer_inner( ...@@ -450,7 +454,14 @@ async fn connection_writer_inner(
debug!("Connected to {}", addr); debug!("Connected to {}", addr);
while let Ok(msg) = rx.recv_async().await { loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => break,
res = rx.recv_async() => match res {
Ok(msg) => msg,
Err(_) => break,
},
};
if let Err(e) = if let Err(e) =
TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await
{ {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! UDS listener for ActiveMessage transport
//!
//! Mirrors `tcp/listener.rs` but uses `UnixListener`/`UnixStream`.
//! Reuses `TcpFrameCodec` for framing. Supports drain-aware frame handling
//! via `ShutdownState`.
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::StreamExt;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::UnixListener as TokioUnixListener;
use tokio::net::UnixStream;
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
use crate::tcp::TcpFrameCodec;
/// UDS listener for ActiveMessage transport
///
/// Accepts incoming Unix domain socket connections and routes decoded frames
/// to the appropriate transport streams. Supports graceful drain: during drain,
/// new `Message` frames are rejected with a `ShuttingDown` response while
/// `Response`/`Event`/`Ack` frames continue to flow.
pub struct UdsListener {
socket_path: PathBuf,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
}
/// UDS listener that has been bound to a socket path, ready to accept connections.
///
/// Created by [`UdsListener::bind`]. Holding this value proves the OS-level bind
/// succeeded, so callers can detect failures before spawning a task.
pub struct BoundUdsListener {
socket_path: PathBuf,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
listener: TokioUnixListener,
}
impl UdsListener {
/// Create a new builder for UdsListener
pub fn builder() -> UdsListenerBuilder {
UdsListenerBuilder::new()
}
/// Bind to the socket path and return a [`BoundUdsListener`] ready to serve.
///
/// `TokioUnixListener::bind` is synchronous, so this method is also
/// synchronous. Callers that need to propagate bind failures before spawning
/// a task should call `bind()` first, then spawn `bound.serve()`.
pub fn bind(self) -> Result<BoundUdsListener> {
let listener = TokioUnixListener::bind(&self.socket_path)
.with_context(|| format!("Failed to bind UDS listener to {:?}", self.socket_path))?;
info!("UDS listener bound to {:?}", self.socket_path);
Ok(BoundUdsListener {
socket_path: self.socket_path,
adapter: self.adapter,
error_handler: self.error_handler,
shutdown_state: self.shutdown_state,
listener,
})
}
/// Convenience shim: bind and serve in one call.
pub async fn serve(self) -> Result<()> {
self.bind()?.serve().await
}
/// Handle a single UDS connection
async fn handle_connection(
stream: UnixStream,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
) -> Result<()> {
debug!("Configuring UDS connection");
// Create framed stream with zero-copy codec (same as TCP)
let mut framed = Framed::new(stream, TcpFrameCodec::new());
let teardown_token = shutdown_state.teardown_token().clone();
debug!("UDS connection ready for frames");
loop {
tokio::select! {
frame_result = framed.next() => {
match frame_result {
Some(Ok((msg_type, header, payload))) => {
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if shutdown_state.is_draining() && msg_type == MessageType::Message {
debug!(
"Rejecting Message frame during drain (sending ShuttingDown)"
);
// Echo original header back for correlation, empty payload
if let Err(e) = TcpFrameCodec::encode_frame(
framed.get_mut(),
MessageType::ShuttingDown,
&header,
&[],
)
.await
{
warn!(
"Failed to send ShuttingDown frame: {}",
e
);
}
continue;
}
if let Err(e) = Self::route_frame(
msg_type,
header,
payload,
&adapter,
&error_handler,
)
.await
{
warn!(
"Failed to route {:?} frame from UDS: {}",
msg_type, e
);
}
}
Some(Err(e)) => {
error!("Frame decode error from UDS: {}", e);
break;
}
None => {
debug!("UDS connection closed gracefully");
break;
}
}
}
_ = teardown_token.cancelled() => {
debug!("UDS connection handler torn down");
break;
}
}
}
Ok(())
}
/// Route a decoded frame to the appropriate stream
async fn route_frame(
msg_type: MessageType,
header: Bytes,
payload: Bytes,
adapter: &TransportAdapter,
error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
let sender = match msg_type {
MessageType::Message => &adapter.message_stream,
MessageType::Response => &adapter.response_stream,
MessageType::Ack | MessageType::Event => &adapter.event_stream,
MessageType::ShuttingDown => {
// ShuttingDown is an outbound-only frame type; receiving it here
// means a remote peer rejected our request. Route to the response
// stream so higher layers can handle the rejection via correlation.
&adapter.response_stream
}
};
match sender.send_async((header, payload)).await {
Ok(_) => Ok(()),
Err(e) => {
error_handler.on_error(e.0.0, e.0.1, format!("Failed to route {:?}", msg_type));
Err(anyhow::anyhow!("Failed to send to stream"))
}
}
}
}
impl BoundUdsListener {
/// Accept connections until the teardown token is cancelled.
pub async fn serve(self) -> Result<()> {
let teardown_token = self.shutdown_state.teardown_token().clone();
loop {
tokio::select! {
accept_result = self.listener.accept() => {
match accept_result {
Ok((stream, _addr)) => {
debug!("Accepted UDS connection");
let adapter = self.adapter.clone();
let error_handler = self.error_handler.clone();
let shutdown_state = self.shutdown_state.clone();
tokio::spawn(async move {
if let Err(e) = UdsListener::handle_connection(
stream,
adapter,
error_handler,
shutdown_state,
)
.await
{
warn!("Error handling UDS connection: {}", e);
}
});
}
Err(e) => {
error!("Failed to accept UDS connection: {}", e);
}
}
}
_ = teardown_token.cancelled() => {
info!("UDS listener shutting down (teardown)");
break;
}
}
}
// Clean up socket file
std::fs::remove_file(&self.socket_path).ok();
Ok(())
}
}
/// Builder for UdsListener
pub struct UdsListenerBuilder {
socket_path: Option<PathBuf>,
adapter: Option<TransportAdapter>,
error_handler: Option<Arc<dyn TransportErrorHandler>>,
shutdown_state: Option<ShutdownState>,
}
impl UdsListenerBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
socket_path: None,
adapter: None,
error_handler: None,
shutdown_state: None,
}
}
/// Set the socket path
pub fn socket_path(mut self, path: PathBuf) -> Self {
self.socket_path = Some(path);
self
}
/// Set the transport adapter
pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
self.adapter = Some(adapter);
self
}
/// Set the error handler
pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
self.error_handler = Some(handler);
self
}
/// Set the shutdown state for graceful drain coordination
pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
self.shutdown_state = Some(state);
self
}
/// Build the UdsListener
pub fn build(self) -> Result<UdsListener> {
let socket_path = self
.socket_path
.ok_or_else(|| anyhow::anyhow!("socket_path is required"))?;
let adapter = self
.adapter
.ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
let error_handler = self
.error_handler
.ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
let shutdown_state = self.shutdown_state.unwrap_or_default();
Ok(UdsListener {
socket_path,
adapter,
error_handler,
shutdown_state,
})
}
}
impl Default for UdsListenerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::make_channels;
struct TestErrorHandler;
impl TransportErrorHandler for TestErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
eprintln!("Test error handler: {}", error);
}
}
#[test]
fn test_builder_requires_fields() {
let result = UdsListener::builder().build();
assert!(result.is_err());
}
#[tokio::test]
async fn test_builder_with_all_fields() {
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = UdsListener::builder()
.socket_path(PathBuf::from("/tmp/test-uds-listener.sock"))
.adapter(adapter)
.error_handler(error_handler)
.build();
assert!(result.is_ok());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Unix Domain Socket (UDS) Transport Module
//!
//! This module provides a UDS transport implementation that mirrors the TCP transport
//! but uses Unix domain sockets instead of TCP connections. It reuses the same
//! zero-copy frame codec (`TcpFrameCodec`) since the framing protocol is transport-agnostic.
//!
//! Key differences from TCP:
//! - Uses `PathBuf` instead of `SocketAddr`
//! - Uses `UnixStream`/`UnixListener` instead of `TcpStream`/`TcpListener`
//! - No TCP-specific options (nodelay, keepalive, CPU pinning)
//! - Endpoint format: `uds:///path/to/socket`
//!
//! This transport is ideal for same-host communication (e.g., daemon-to-container via
//! bind-mounted sockets), avoiding the overhead of the TCP/IP stack entirely.
mod listener;
mod transport;
pub use listener::{UdsListener, UdsListenerBuilder};
pub use transport::{UdsTransport, UdsTransportBuilder};
This diff is collapsed.
...@@ -24,6 +24,9 @@ use velo_transports::{ ...@@ -24,6 +24,9 @@ use velo_transports::{
tcp::{TcpTransport, TcpTransportBuilder}, tcp::{TcpTransport, TcpTransportBuilder},
}; };
#[cfg(unix)]
use velo_transports::uds::{UdsTransport, UdsTransportBuilder};
use std::sync::Once; use std::sync::Once;
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::FmtSubscriber;
...@@ -252,6 +255,24 @@ impl TestTransportHandle<TcpTransport> { ...@@ -252,6 +255,24 @@ impl TestTransportHandle<TcpTransport> {
} }
} }
// UDS-specific convenience constructors
#[cfg(unix)]
impl TestTransportHandle<UdsTransport> {
/// Create a new UDS transport using a temp directory socket path
pub async fn new_uds() -> anyhow::Result<Self> {
Self::with_factory(|| {
let dir = std::env::temp_dir().join(format!(
"velo-uds-test-{}",
velo_transports::InstanceId::new_v4()
));
std::fs::create_dir_all(&dir)?;
let socket_path = dir.join("transport.sock");
UdsTransportBuilder::new().socket_path(&socket_path).build()
})
.await
}
}
// // UCX-specific convenience constructors // // UCX-specific convenience constructors
// #[cfg(feature = "ucx")] // #[cfg(feature = "ucx")]
// impl TestTransportHandle<UcxTransport> { // impl TestTransportHandle<UcxTransport> {
...@@ -416,6 +437,24 @@ impl TestCluster<TcpTransport> { ...@@ -416,6 +437,24 @@ impl TestCluster<TcpTransport> {
} }
} }
// UDS-specific convenience constructor
#[cfg(unix)]
impl TestCluster<UdsTransport> {
/// Create a new UDS test cluster with the specified number of transports
pub async fn new_uds(size: usize) -> anyhow::Result<Self> {
Self::with_factory(size, || {
let dir = std::env::temp_dir().join(format!(
"velo-uds-test-{}",
velo_transports::InstanceId::new_v4()
));
std::fs::create_dir_all(&dir)?;
let socket_path = dir.join("transport.sock");
UdsTransportBuilder::new().socket_path(&socket_path).build()
})
.await
}
}
// // HTTP-specific convenience constructor // // HTTP-specific convenience constructor
// #[cfg(feature = "http")] // #[cfg(feature = "http")]
// impl TestCluster<HttpTransport> { // impl TestCluster<HttpTransport> {
...@@ -535,6 +574,23 @@ impl TransportFactory for TcpFactory { ...@@ -535,6 +574,23 @@ impl TransportFactory for TcpFactory {
} }
} }
/// UDS transport factory
#[cfg(unix)]
pub struct UdsFactory;
#[cfg(unix)]
impl TransportFactory for UdsFactory {
type Transport = UdsTransport;
async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
TestTransportHandle::new_uds().await
}
async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
TestCluster::new_uds(size).await
}
}
// /// UCX transport factory // /// UCX transport factory
// #[cfg(feature = "ucx")] // #[cfg(feature = "ucx")]
// pub struct UcxFactory; // pub struct UcxFactory;
......
...@@ -359,7 +359,10 @@ async fn test_connection_writer_exits_on_teardown() { ...@@ -359,7 +359,10 @@ async fn test_connection_writer_exits_on_teardown() {
// Give writer tasks time to exit // Give writer tasks time to exit
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
// Sending should now fail (error handler gets invoked, not a panic) // Sending after shutdown: the cancel_token is already cancelled, so any new
// writer task returns immediately without connecting. The error handler must
// be invoked and the message must not arrive on handle_b.
handle_a.error_handler.clear();
handle_a.send( handle_a.send(
handle_b.instance_id, handle_b.instance_id,
b"should-fail".to_vec(), b"should-fail".to_vec(),
...@@ -367,11 +370,25 @@ async fn test_connection_writer_exits_on_teardown() { ...@@ -367,11 +370,25 @@ async fn test_connection_writer_exits_on_teardown() {
MessageType::Message, MessageType::Message,
); );
// Give time for async error path // Give time for the async error path to complete
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// The message either goes to error handler or is silently dropped // Error handler must have been invoked for the failed send
// (connection cleared during shutdown). Just verify no panic occurred. assert!(
handle_a.error_handler.error_count() >= 1,
"error handler should be invoked for post-shutdown send"
);
// The message must not have been delivered to handle_b
let not_delivered = timeout(
Duration::from_millis(100),
handle_b.streams.message_stream.recv_async(),
)
.await;
assert!(
not_delivered.is_err(),
"post-shutdown message must not arrive at handle_b"
);
handle_b.streams.shutdown_state.teardown_token().cancel(); handle_b.streams.shutdown_state.teardown_token().cancel();
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for UDS transport
#![cfg(unix)]
mod common;
use common::{UdsFactory, scenarios};
#[tokio::test]
async fn test_single_message_round_trip() {
scenarios::single_message_round_trip::<UdsFactory>().await;
}
#[tokio::test]
async fn test_bidirectional_messaging() {
scenarios::bidirectional_messaging::<UdsFactory>().await;
}
#[tokio::test]
async fn test_multiple_messages_same_connection() {
scenarios::multiple_messages_same_connection::<UdsFactory>().await;
}
#[tokio::test]
async fn test_response_message_type() {
scenarios::response_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_event_message_type() {
scenarios::event_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_ack_message_type() {
scenarios::ack_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_mixed_message_types() {
scenarios::mixed_message_types::<UdsFactory>().await;
}
#[tokio::test]
async fn test_large_payload() {
scenarios::large_payload::<UdsFactory>().await;
}
#[tokio::test]
async fn test_empty_header_and_payload() {
scenarios::empty_header_and_payload::<UdsFactory>().await;
}
#[tokio::test]
async fn test_cluster_mesh_communication() {
scenarios::cluster_mesh_communication::<UdsFactory>().await;
}
#[tokio::test]
async fn test_concurrent_senders() {
scenarios::concurrent_senders::<UdsFactory>().await;
}
#[tokio::test]
async fn test_send_to_unregistered_peer() {
scenarios::send_to_unregistered_peer::<UdsFactory>().await;
}
#[tokio::test]
async fn test_connection_reuse() {
scenarios::connection_reuse::<UdsFactory>().await;
}
#[tokio::test]
async fn test_graceful_shutdown() {
scenarios::graceful_shutdown::<UdsFactory>().await;
}
#[tokio::test]
async fn test_high_throughput() {
scenarios::high_throughput::<UdsFactory>().await;
}
#[tokio::test]
async fn test_zero_copy_efficiency() {
scenarios::zero_copy_efficiency::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_rejects_messages() {
scenarios::drain_rejects_messages::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_responses() {
scenarios::drain_accepts_responses::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_events() {
scenarios::drain_accepts_events::<UdsFactory>().await;
}
#[tokio::test]
async fn test_health_during_drain() {
scenarios::health_during_drain::<UdsFactory>().await;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for UDS graceful shutdown
//!
//! These tests verify the 3-phase shutdown behavior:
//! 1. Gate: new Message frames are rejected with ShuttingDown
//! 2. Drain: in-flight work completes, responses/events still flow
//! 3. Teardown: listener and writer tasks exit
#![cfg(unix)]
mod common;
use bytes::Bytes;
use std::time::Duration;
use tokio::time::{sleep, timeout};
use velo_transports::tcp::TcpFrameCodec;
use velo_transports::uds::UdsTransport;
use velo_transports::{MessageType, Transport};
use common::TestTransportHandle;
/// Get the socket path from a UDS transport by parsing its WorkerAddress.
fn get_socket_path(handle: &TestTransportHandle<UdsTransport>) -> std::path::PathBuf {
let addr = handle.transport.address();
let key = handle.transport.key();
let endpoint = addr.get_entry(&key).unwrap().unwrap();
let s = std::str::from_utf8(&endpoint).unwrap();
let s = s.strip_prefix("uds://").unwrap_or(s);
std::path::PathBuf::from(s)
}
/// Helper: connect a raw UDS client to the transport's socket and send a frame.
async fn connect_and_send_frame(
socket_path: &std::path::Path,
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> tokio::net::UnixStream {
let mut stream = tokio::net::UnixStream::connect(socket_path).await.unwrap();
TcpFrameCodec::encode_frame(&mut stream, msg_type, header, payload)
.await
.unwrap();
stream
}
/// Helper: read one frame from a raw UDS stream.
async fn read_one_frame(stream: &mut tokio::net::UnixStream) -> (MessageType, Bytes, Bytes) {
use futures::StreamExt;
use tokio_util::codec::Framed;
let mut framed = Framed::new(stream, TcpFrameCodec::new());
framed.next().await.unwrap().unwrap()
}
// --- Test: Drain rejects Message frames ---
#[tokio::test]
async fn test_uds_drain_rejects_messages() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Begin drain
handle.streams.shutdown_state.begin_drain();
// Give listener time to be ready
sleep(Duration::from_millis(50)).await;
// Connect and send a Message frame
let mut stream = connect_and_send_frame(
&socket_path,
MessageType::Message,
b"req-header",
b"req-payload",
)
.await;
// Should get ShuttingDown back
let (msg_type, header, payload) = read_one_frame(&mut stream).await;
assert_eq!(msg_type, MessageType::ShuttingDown);
assert_eq!(&header[..], b"req-header"); // Original header echoed back
assert_eq!(payload.len(), 0); // Empty payload
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Drain accepts Response frames ---
#[tokio::test]
async fn test_uds_drain_accepts_responses() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Begin drain
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Connect and send a Response frame
connect_and_send_frame(
&socket_path,
MessageType::Response,
b"resp-header",
b"resp-payload",
)
.await;
// Should arrive on the response stream
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"resp-header");
assert_eq!(&payload[..], b"resp-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Drain accepts Event frames ---
#[tokio::test]
async fn test_uds_drain_accepts_events() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
connect_and_send_frame(
&socket_path,
MessageType::Event,
b"evt-header",
b"evt-payload",
)
.await;
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.event_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"evt-header");
assert_eq!(&payload[..], b"evt-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: New connection during drain still accepts responses ---
#[tokio::test]
async fn test_uds_new_connection_during_drain() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Begin drain BEFORE connecting
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Establish a NEW connection after drain starts
connect_and_send_frame(
&socket_path,
MessageType::Response,
b"new-resp",
b"new-payload",
)
.await;
// Should arrive on the response stream
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"new-resp");
assert_eq!(&payload[..], b"new-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Full graceful shutdown lifecycle ---
#[tokio::test]
async fn test_uds_graceful_shutdown_lifecycle() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Verify normal operation: send a message, receive it
connect_and_send_frame(
&socket_path,
MessageType::Message,
b"normal-msg",
b"normal-pay",
)
.await;
let (header, _payload) = timeout(
Duration::from_secs(2),
handle.streams.message_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"normal-msg");
// Acquire an InFlightGuard (simulate in-progress request)
let guard = handle.streams.shutdown_state.acquire();
assert_eq!(handle.streams.shutdown_state.in_flight_count(), 1);
// Begin drain (Phase 1)
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Verify new messages are rejected
let mut stream =
connect_and_send_frame(&socket_path, MessageType::Message, b"reject-me", b"").await;
let (msg_type, _, _) = read_one_frame(&mut stream).await;
assert_eq!(msg_type, MessageType::ShuttingDown);
// Verify responses still flow
connect_and_send_frame(&socket_path, MessageType::Response, b"still-ok", b"data").await;
let (header, _) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"still-ok");
// Spawn graceful_shutdown in background (will block on drain since guard is held)
let shutdown_state = handle.streams.shutdown_state.clone();
let shutdown_handle = tokio::spawn(async move {
// Phase 2: wait for drain
shutdown_state.wait_for_drain().await;
// Phase 3: teardown
shutdown_state.teardown_token().cancel();
});
// Verify shutdown hasn't completed yet (guard still held)
sleep(Duration::from_millis(100)).await;
assert!(!shutdown_handle.is_finished());
// Drop guard -> drain completes -> teardown fires
drop(guard);
timeout(Duration::from_secs(2), shutdown_handle)
.await
.expect("shutdown should complete")
.unwrap();
assert!(
handle
.streams
.shutdown_state
.teardown_token()
.is_cancelled()
);
}
// --- Test: Shutdown timeout forces teardown ---
#[tokio::test]
async fn test_uds_shutdown_timeout_forces_teardown() {
let handle = TestTransportHandle::new_uds().await.unwrap();
// Acquire guard and hold it
let _guard = handle.streams.shutdown_state.acquire();
let shutdown_state = handle.streams.shutdown_state.clone();
let shutdown_handle = tokio::spawn(async move {
shutdown_state.begin_drain();
// Phase 2: wait with short timeout
let _ =
tokio::time::timeout(Duration::from_millis(100), shutdown_state.wait_for_drain()).await;
// Phase 3: teardown (forced, guard still held)
shutdown_state.teardown_token().cancel();
});
timeout(Duration::from_secs(2), shutdown_handle)
.await
.expect("shutdown should complete via timeout")
.unwrap();
// Teardown should have fired even though guard is held
assert!(
handle
.streams
.shutdown_state
.teardown_token()
.is_cancelled()
);
// Guard is still held (not a problem — teardown was forced)
assert_eq!(handle.streams.shutdown_state.in_flight_count(), 1);
}
// --- Test: Outbound sends during drain ---
#[tokio::test]
async fn test_uds_outbound_sends_during_drain() {
// Create two transports and register them as peers
let handle_a = TestTransportHandle::new_uds().await.unwrap();
let handle_b = TestTransportHandle::new_uds().await.unwrap();
handle_a.register_peer(&handle_b).unwrap();
handle_b.register_peer(&handle_a).unwrap();
// Begin drain on transport A
handle_a.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Send a Response from A to B (outbound sends should work during drain)
handle_a.send(
handle_b.instance_id,
b"response-hdr".to_vec(),
b"response-pay".to_vec(),
MessageType::Response,
);
// B should receive the response
let (header, payload) = timeout(
Duration::from_secs(2),
handle_b.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"response-hdr");
assert_eq!(&payload[..], b"response-pay");
handle_a.streams.shutdown_state.teardown_token().cancel();
handle_b.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Connection writer exits on teardown ---
#[tokio::test]
async fn test_uds_connection_writer_exits_on_teardown() {
let handle_a = TestTransportHandle::new_uds().await.unwrap();
let handle_b = TestTransportHandle::new_uds().await.unwrap();
handle_a.register_peer(&handle_b).unwrap();
// Send a message to establish the connection writer task
handle_a.send(
handle_b.instance_id,
b"setup".to_vec(),
b"data".to_vec(),
MessageType::Message,
);
// Wait for it to arrive
timeout(
Duration::from_secs(2),
handle_b.streams.message_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
// Shutdown transport A
handle_a.transport.shutdown();
// Give writer tasks time to exit
sleep(Duration::from_millis(200)).await;
// Sending after shutdown: the cancel_token is already cancelled, so any new
// writer task returns immediately without connecting. The error handler must
// be invoked and the message must not arrive on handle_b.
handle_a.error_handler.clear();
handle_a.send(
handle_b.instance_id,
b"should-fail".to_vec(),
b"data".to_vec(),
MessageType::Message,
);
// Give time for the async error path to complete
sleep(Duration::from_millis(100)).await;
// Error handler must have been invoked for the failed send
assert!(
handle_a.error_handler.error_count() >= 1,
"error handler should be invoked for post-shutdown send"
);
// The message must not have been delivered to handle_b
let not_delivered = timeout(
Duration::from_millis(100),
handle_b.streams.message_stream.recv_async(),
)
.await;
assert!(
not_delivered.is_err(),
"post-shutdown message must not arrive at handle_b"
);
handle_b.streams.shutdown_state.teardown_token().cancel();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment