"tests/vscode:/vscode.git/clone" did not exist on "74a566761be978a3ae17026675d9d42acd174fec"
Unverified Commit 9e9ca3e2 authored by Patrick's avatar Patrick Committed by GitHub
Browse files

feat(velo): add unix domain socket (#7197)


Signed-off-by: default avatarPatrick Riel <priel@nvidia.com>
Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent d5680ed9
...@@ -31,6 +31,9 @@ mod address; ...@@ -31,6 +31,9 @@ mod address;
pub mod tcp; pub mod tcp;
#[cfg(unix)]
pub mod uds;
// #[cfg(feature = "ucx")] // #[cfg(feature = "ucx")]
// pub mod ucx; // pub mod ucx;
......
...@@ -392,9 +392,9 @@ async fn connection_writer_task( ...@@ -392,9 +392,9 @@ async fn connection_writer_task(
instance_id: crate::InstanceId, instance_id: crate::InstanceId,
rx: flume::Receiver<SendTask>, rx: flume::Receiver<SendTask>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>, connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
_cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let result = connection_writer_inner(addr, instance_id, &rx).await; let result = connection_writer_inner(addr, instance_id, &rx, &cancel_token).await;
// Always drain queued messages and notify their error handlers. // Always drain queued messages and notify their error handlers.
// //
...@@ -426,10 +426,14 @@ async fn connection_writer_inner( ...@@ -426,10 +426,14 @@ async fn connection_writer_inner(
addr: SocketAddr, addr: SocketAddr,
instance_id: crate::InstanceId, instance_id: crate::InstanceId,
rx: &flume::Receiver<SendTask>, rx: &flume::Receiver<SendTask>,
cancel_token: &CancellationToken,
) -> Result<()> { ) -> Result<()> {
debug!("Connecting to {}", addr); debug!("Connecting to {}", addr);
let mut stream = TcpStream::connect(addr).await.context("connect failed")?; let mut stream = tokio::select! {
_ = cancel_token.cancelled() => return Ok(()),
res = TcpStream::connect(addr) => res.context("connect failed")?,
};
if let Err(e) = stream.set_nodelay(true) { if let Err(e) = stream.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY: {}", e); warn!("Failed to set TCP_NODELAY: {}", e);
...@@ -450,7 +454,14 @@ async fn connection_writer_inner( ...@@ -450,7 +454,14 @@ async fn connection_writer_inner(
debug!("Connected to {}", addr); debug!("Connected to {}", addr);
while let Ok(msg) = rx.recv_async().await { loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => break,
res = rx.recv_async() => match res {
Ok(msg) => msg,
Err(_) => break,
},
};
if let Err(e) = if let Err(e) =
TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await
{ {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! UDS listener for ActiveMessage transport
//!
//! Mirrors `tcp/listener.rs` but uses `UnixListener`/`UnixStream`.
//! Reuses `TcpFrameCodec` for framing. Supports drain-aware frame handling
//! via `ShutdownState`.
use anyhow::{Context, Result};
use bytes::Bytes;
use futures::StreamExt;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::UnixListener as TokioUnixListener;
use tokio::net::UnixStream;
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use crate::{MessageType, ShutdownState, TransportAdapter, TransportErrorHandler};
use crate::tcp::TcpFrameCodec;
/// UDS listener for ActiveMessage transport
///
/// Accepts incoming Unix domain socket connections and routes decoded frames
/// to the appropriate transport streams. Supports graceful drain: during drain,
/// new `Message` frames are rejected with a `ShuttingDown` response while
/// `Response`/`Event`/`Ack` frames continue to flow.
pub struct UdsListener {
socket_path: PathBuf,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
}
/// UDS listener that has been bound to a socket path, ready to accept connections.
///
/// Created by [`UdsListener::bind`]. Holding this value proves the OS-level bind
/// succeeded, so callers can detect failures before spawning a task.
pub struct BoundUdsListener {
socket_path: PathBuf,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
listener: TokioUnixListener,
}
impl UdsListener {
/// Create a new builder for UdsListener
pub fn builder() -> UdsListenerBuilder {
UdsListenerBuilder::new()
}
/// Bind to the socket path and return a [`BoundUdsListener`] ready to serve.
///
/// `TokioUnixListener::bind` is synchronous, so this method is also
/// synchronous. Callers that need to propagate bind failures before spawning
/// a task should call `bind()` first, then spawn `bound.serve()`.
pub fn bind(self) -> Result<BoundUdsListener> {
let listener = TokioUnixListener::bind(&self.socket_path)
.with_context(|| format!("Failed to bind UDS listener to {:?}", self.socket_path))?;
info!("UDS listener bound to {:?}", self.socket_path);
Ok(BoundUdsListener {
socket_path: self.socket_path,
adapter: self.adapter,
error_handler: self.error_handler,
shutdown_state: self.shutdown_state,
listener,
})
}
/// Convenience shim: bind and serve in one call.
pub async fn serve(self) -> Result<()> {
self.bind()?.serve().await
}
/// Handle a single UDS connection
async fn handle_connection(
stream: UnixStream,
adapter: TransportAdapter,
error_handler: Arc<dyn TransportErrorHandler>,
shutdown_state: ShutdownState,
) -> Result<()> {
debug!("Configuring UDS connection");
// Create framed stream with zero-copy codec (same as TCP)
let mut framed = Framed::new(stream, TcpFrameCodec::new());
let teardown_token = shutdown_state.teardown_token().clone();
debug!("UDS connection ready for frames");
loop {
tokio::select! {
frame_result = framed.next() => {
match frame_result {
Some(Ok((msg_type, header, payload))) => {
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if shutdown_state.is_draining() && msg_type == MessageType::Message {
debug!(
"Rejecting Message frame during drain (sending ShuttingDown)"
);
// Echo original header back for correlation, empty payload
if let Err(e) = TcpFrameCodec::encode_frame(
framed.get_mut(),
MessageType::ShuttingDown,
&header,
&[],
)
.await
{
warn!(
"Failed to send ShuttingDown frame: {}",
e
);
}
continue;
}
if let Err(e) = Self::route_frame(
msg_type,
header,
payload,
&adapter,
&error_handler,
)
.await
{
warn!(
"Failed to route {:?} frame from UDS: {}",
msg_type, e
);
}
}
Some(Err(e)) => {
error!("Frame decode error from UDS: {}", e);
break;
}
None => {
debug!("UDS connection closed gracefully");
break;
}
}
}
_ = teardown_token.cancelled() => {
debug!("UDS connection handler torn down");
break;
}
}
}
Ok(())
}
/// Route a decoded frame to the appropriate stream
async fn route_frame(
msg_type: MessageType,
header: Bytes,
payload: Bytes,
adapter: &TransportAdapter,
error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
let sender = match msg_type {
MessageType::Message => &adapter.message_stream,
MessageType::Response => &adapter.response_stream,
MessageType::Ack | MessageType::Event => &adapter.event_stream,
MessageType::ShuttingDown => {
// ShuttingDown is an outbound-only frame type; receiving it here
// means a remote peer rejected our request. Route to the response
// stream so higher layers can handle the rejection via correlation.
&adapter.response_stream
}
};
match sender.send_async((header, payload)).await {
Ok(_) => Ok(()),
Err(e) => {
error_handler.on_error(e.0.0, e.0.1, format!("Failed to route {:?}", msg_type));
Err(anyhow::anyhow!("Failed to send to stream"))
}
}
}
}
impl BoundUdsListener {
/// Accept connections until the teardown token is cancelled.
pub async fn serve(self) -> Result<()> {
let teardown_token = self.shutdown_state.teardown_token().clone();
loop {
tokio::select! {
accept_result = self.listener.accept() => {
match accept_result {
Ok((stream, _addr)) => {
debug!("Accepted UDS connection");
let adapter = self.adapter.clone();
let error_handler = self.error_handler.clone();
let shutdown_state = self.shutdown_state.clone();
tokio::spawn(async move {
if let Err(e) = UdsListener::handle_connection(
stream,
adapter,
error_handler,
shutdown_state,
)
.await
{
warn!("Error handling UDS connection: {}", e);
}
});
}
Err(e) => {
error!("Failed to accept UDS connection: {}", e);
}
}
}
_ = teardown_token.cancelled() => {
info!("UDS listener shutting down (teardown)");
break;
}
}
}
// Clean up socket file
std::fs::remove_file(&self.socket_path).ok();
Ok(())
}
}
/// Builder for UdsListener
pub struct UdsListenerBuilder {
socket_path: Option<PathBuf>,
adapter: Option<TransportAdapter>,
error_handler: Option<Arc<dyn TransportErrorHandler>>,
shutdown_state: Option<ShutdownState>,
}
impl UdsListenerBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
socket_path: None,
adapter: None,
error_handler: None,
shutdown_state: None,
}
}
/// Set the socket path
pub fn socket_path(mut self, path: PathBuf) -> Self {
self.socket_path = Some(path);
self
}
/// Set the transport adapter
pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
self.adapter = Some(adapter);
self
}
/// Set the error handler
pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
self.error_handler = Some(handler);
self
}
/// Set the shutdown state for graceful drain coordination
pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
self.shutdown_state = Some(state);
self
}
/// Build the UdsListener
pub fn build(self) -> Result<UdsListener> {
let socket_path = self
.socket_path
.ok_or_else(|| anyhow::anyhow!("socket_path is required"))?;
let adapter = self
.adapter
.ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
let error_handler = self
.error_handler
.ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
let shutdown_state = self.shutdown_state.unwrap_or_default();
Ok(UdsListener {
socket_path,
adapter,
error_handler,
shutdown_state,
})
}
}
impl Default for UdsListenerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::make_channels;
struct TestErrorHandler;
impl TransportErrorHandler for TestErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
eprintln!("Test error handler: {}", error);
}
}
#[test]
fn test_builder_requires_fields() {
let result = UdsListener::builder().build();
assert!(result.is_err());
}
#[tokio::test]
async fn test_builder_with_all_fields() {
let (adapter, _streams) = make_channels();
let error_handler = Arc::new(TestErrorHandler);
let result = UdsListener::builder()
.socket_path(PathBuf::from("/tmp/test-uds-listener.sock"))
.adapter(adapter)
.error_handler(error_handler)
.build();
assert!(result.is_ok());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Unix Domain Socket (UDS) Transport Module
//!
//! This module provides a UDS transport implementation that mirrors the TCP transport
//! but uses Unix domain sockets instead of TCP connections. It reuses the same
//! zero-copy frame codec (`TcpFrameCodec`) since the framing protocol is transport-agnostic.
//!
//! Key differences from TCP:
//! - Uses `PathBuf` instead of `SocketAddr`
//! - Uses `UnixStream`/`UnixListener` instead of `TcpStream`/`TcpListener`
//! - No TCP-specific options (nodelay, keepalive, CPU pinning)
//! - Endpoint format: `uds:///path/to/socket`
//!
//! This transport is ideal for same-host communication (e.g., daemon-to-container via
//! bind-mounted sockets), avoiding the overhead of the TCP/IP stack entirely.
mod listener;
mod transport;
pub use listener::{UdsListener, UdsListenerBuilder};
pub use transport::{UdsTransport, UdsTransportBuilder};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! UDS transport implementation
//!
//! Structural mirror of the TCP transport (`tcp/transport.rs`), replacing
//! `TcpStream`/`TcpListener` with `UnixStream`/`UnixListener`.
//! Reuses `TcpFrameCodec` for framing since it operates on any `AsyncRead + AsyncWrite`.
use anyhow::{Context, Result};
use bytes::Bytes;
use dashmap::DashMap;
use std::os::unix::fs::FileTypeExt;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tokio::net::UnixStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use crate::transport::{HealthCheckError, ShutdownState, TransportError, TransportErrorHandler};
use crate::{MessageType, PeerInfo, Transport, TransportAdapter, TransportKey, WorkerAddress};
use super::listener::UdsListener;
use crate::tcp::TcpFrameCodec;
/// UDS transport with lock-free concurrent access
///
/// Mirrors `TcpTransport` but uses Unix domain sockets.
pub struct UdsTransport {
key: TransportKey,
socket_path: PathBuf,
local_address: WorkerAddress,
// Shared mutable state with DashMap (lock-free)
peers: Arc<DashMap<crate::InstanceId, PathBuf>>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
// Runtime handle for spawning tasks
runtime: OnceLock<tokio::runtime::Handle>,
// Shutdown coordination
cancel_token: CancellationToken,
shutdown_state: OnceLock<ShutdownState>,
// Send channel capacity for backpressure
channel_capacity: usize,
}
/// Handle to a connection's writer task
#[derive(Clone)]
struct ConnectionHandle {
tx: flume::Sender<SendTask>,
}
/// Task sent to writer task containing pre-encoded frame
struct SendTask {
msg_type: MessageType,
header: Bytes,
payload: Bytes,
on_error: Arc<dyn TransportErrorHandler>,
}
impl SendTask {
fn on_error(self, error: impl Into<String>) {
self.on_error
.on_error(self.header, self.payload, error.into());
}
}
impl UdsTransport {
/// Create a new UDS transport
pub fn new(
socket_path: PathBuf,
key: TransportKey,
local_address: WorkerAddress,
channel_capacity: usize,
) -> Self {
Self {
key,
socket_path,
local_address,
peers: Arc::new(DashMap::new()),
connections: Arc::new(DashMap::new()),
runtime: OnceLock::new(),
cancel_token: CancellationToken::new(),
shutdown_state: OnceLock::new(),
channel_capacity,
}
}
/// Get the socket path this transport is bound to
pub fn socket_path(&self) -> &Path {
&self.socket_path
}
/// Optional: Pre-establish connection after registration
pub fn ensure_connected(&self, instance_id: crate::InstanceId) -> Result<()> {
self.get_or_create_connection(instance_id)?;
Ok(())
}
/// Get or create a connection to a peer (lazy initialization)
fn get_or_create_connection(&self, instance_id: crate::InstanceId) -> Result<ConnectionHandle> {
// Fast path: connection already exists and is alive
if let Some(handle) = self.connections.get(&instance_id) {
if !handle.tx.is_disconnected() {
return Ok(handle.clone());
}
// Stale — drop guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
}
let rt = self.runtime.get().ok_or(TransportError::NotStarted)?;
// Atomic check-and-insert via entry API
let handle = match self.connections.entry(instance_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
if !entry.get().tx.is_disconnected() {
entry.get().clone()
} else {
// Stale entry — replace in-place with a fresh connection
let handle = self.create_connection(instance_id, rt)?;
entry.insert(handle.clone());
handle
}
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let handle = self.create_connection(instance_id, rt)?;
entry.insert(handle.clone());
handle
}
};
Ok(handle)
}
/// Create a new connection handle and spawn the writer task.
fn create_connection(
&self,
instance_id: crate::InstanceId,
rt: &tokio::runtime::Handle,
) -> Result<ConnectionHandle> {
let path = self
.peers
.get(&instance_id)
.ok_or(TransportError::PeerNotRegistered(instance_id))?
.value()
.clone();
let (tx, rx) = flume::bounded(self.channel_capacity);
let handle = ConnectionHandle { tx };
let cancel = self.cancel_token.clone();
let conns = Arc::clone(&self.connections);
debug!("Created new UDS connection to {} ({:?})", instance_id, path);
rt.spawn(connection_writer_task(path, instance_id, rx, conns, cancel));
Ok(handle)
}
}
impl Transport for UdsTransport {
fn key(&self) -> TransportKey {
self.key.clone()
}
fn address(&self) -> WorkerAddress {
self.local_address.clone()
}
fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError> {
// Get endpoint from peer's address
let endpoint = peer_info
.worker_address()
.get_entry(&self.key)
.map_err(|_| TransportError::NoEndpoint)?
.ok_or(TransportError::NoEndpoint)?;
// Parse UDS endpoint (expected format: "uds:///path/to/socket" or "/path/to/socket")
let path = parse_uds_endpoint(&endpoint).map_err(|e| {
error!("Failed to parse UDS endpoint: {}", e);
TransportError::InvalidEndpoint
})?;
// Store peer path
self.peers.insert(peer_info.instance_id(), path.clone());
debug!("Registered peer {} at {:?}", peer_info.instance_id(), path);
Ok(())
}
#[inline]
fn send_message(
&self,
instance_id: crate::InstanceId,
header: Vec<u8>,
payload: Vec<u8>,
message_type: MessageType,
on_error: Arc<dyn TransportErrorHandler>,
) {
let header = Bytes::from(header);
let payload = Bytes::from(payload);
let send_msg = SendTask {
msg_type: message_type,
header,
payload,
on_error,
};
// Fast path: try to send on existing connection
let send_msg = match self.connections.get(&instance_id) {
Some(handle) => match handle.tx.try_send(send_msg) {
Ok(()) => return,
Err(flume::TrySendError::Full(send_msg)) => send_msg,
Err(flume::TrySendError::Disconnected(send_msg)) => {
// Drop the guard before mutating the map
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
// Fall through to slow path to create a fresh connection
send_msg
}
},
None => send_msg,
};
// Slow path: create new connection
let rt = match self.runtime.get() {
Some(rt) => rt,
None => {
send_msg.on_error("Transport not started");
return;
}
};
let handle = match self.get_or_create_connection(instance_id) {
Ok(h) => h,
Err(e) => {
send_msg.on_error(format!("Failed to create connection: {}", e));
return;
}
};
rt.spawn(async move {
if let Err(flume::SendError(send_msg)) = handle.tx.send_async(send_msg).await {
send_msg.on_error("Connection closed");
}
});
}
fn start(
&self,
_instance_id: crate::InstanceId,
channels: TransportAdapter,
rt: tokio::runtime::Handle,
) -> futures::future::BoxFuture<'_, anyhow::Result<()>> {
// Store runtime handle for use in send_message
self.runtime.set(rt.clone()).ok();
// Capture shutdown state from the adapter
self.shutdown_state
.set(channels.shutdown_state.clone())
.ok();
let socket_path = self.socket_path.clone();
let shutdown_state = channels.shutdown_state.clone();
Box::pin(async move {
struct DefaultErrorHandler;
impl TransportErrorHandler for DefaultErrorHandler {
fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
warn!("UDS transport error: {}", error);
}
}
// Remove a stale socket file only when it is safe to do so.
if socket_path.exists() {
let is_socket = std::fs::metadata(&socket_path)
.map(|m| m.file_type().is_socket())
.unwrap_or(false);
if !is_socket {
anyhow::bail!(
"path {:?} exists and is not a Unix domain socket",
socket_path
);
}
// Probe liveness: a successful connect means a live listener owns it.
match tokio::time::timeout(
Duration::from_millis(100),
UnixStream::connect(&socket_path),
)
.await
{
Ok(Ok(_)) => {
anyhow::bail!(
"a live UDS listener is already running at {:?}",
socket_path
);
}
_ => {
// Stale (connection refused / timeout) — safe to unlink.
std::fs::remove_file(&socket_path).ok();
}
}
}
// Build and bind before spawning so that start() only returns Ok
// after the OS-level bind succeeds.
let uds_listener = UdsListener::builder()
.socket_path(socket_path.clone())
.adapter(channels)
.error_handler(Arc::new(DefaultErrorHandler))
.shutdown_state(shutdown_state)
.build()?;
let bound_listener = uds_listener.bind()?;
rt.spawn(async move {
if let Err(e) = bound_listener.serve().await {
error!("UDS listener error: {}", e);
}
});
info!("UDS transport started on {:?}", socket_path);
Ok(())
})
}
fn begin_drain(&self) {
if let Some(state) = self.shutdown_state.get() {
state.begin_drain();
}
}
fn shutdown(&self) {
info!("Shutting down UDS transport");
// Cancel the teardown token (Phase 3) to stop the listener and connection handlers
if let Some(state) = self.shutdown_state.get() {
state.teardown_token().cancel();
}
self.cancel_token.cancel();
// Clear connections
self.connections.clear();
}
fn check_health(
&self,
instance_id: crate::InstanceId,
timeout: Duration,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
> {
Box::pin(async move {
let connection_exists = self.connections.contains_key(&instance_id);
if let Some(handle) = self.connections.get(&instance_id) {
if !handle.tx.is_disconnected() {
return Ok(());
}
// Channel is disconnected — drop guard and remove stale entry
drop(handle);
self.connections
.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
}
// No existing connection or connection is dead - verify peer is reachable
let path = self
.peers
.get(&instance_id)
.ok_or(HealthCheckError::PeerNotRegistered)?
.value()
.clone();
// Try to connect (and immediately drop) to verify peer is reachable
match tokio::time::timeout(timeout, UnixStream::connect(&path)).await {
Ok(Ok(_stream)) => {
if connection_exists {
Ok(())
} else {
Err(HealthCheckError::NeverConnected)
}
}
Ok(Err(_)) => Err(HealthCheckError::ConnectionFailed),
Err(_) => Err(HealthCheckError::Timeout),
}
})
}
}
/// Connection writer task for UDS
///
/// Mirrors the TCP connection_writer_task. Cleanup (draining queued messages
/// and removing the stale map entry) always runs, even if the initial connect fails.
async fn connection_writer_task(
path: PathBuf,
instance_id: crate::InstanceId,
rx: flume::Receiver<SendTask>,
connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,
cancel_token: CancellationToken,
) -> Result<()> {
let result = connection_writer_inner(&path, instance_id, &rx, &cancel_token).await;
// Always drain queued messages and notify their error handlers.
while let Ok(msg) = rx.try_recv() {
msg.on_error("Connection closed");
}
// Drop the receiver so our sender half becomes disconnected, then remove
// the stale entry. The predicate ensures we only remove our own entry —
// a replacement connection's tx will still be connected.
drop(rx);
connections.remove_if(&instance_id, |_, h| h.tx.is_disconnected());
debug!("UDS connection to {} ({:?}) closed", instance_id, path);
result
}
/// Inner loop: connect and send frames until the channel closes or a write error occurs.
async fn connection_writer_inner(
path: &Path,
instance_id: crate::InstanceId,
rx: &flume::Receiver<SendTask>,
cancel_token: &CancellationToken,
) -> Result<()> {
debug!("Connecting to UDS {:?}", path);
let mut stream = tokio::select! {
_ = cancel_token.cancelled() => return Ok(()),
res = UnixStream::connect(path) => res.context("UDS connect failed")?,
};
debug!("Connected to UDS {:?}", path);
// Main send loop
loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => break,
res = rx.recv_async() => match res {
Ok(msg) => msg,
Err(_) => break,
},
};
if let Err(e) =
TcpFrameCodec::encode_frame(&mut stream, msg.msg_type, &msg.header, &msg.payload).await
{
error!("Write error to {} ({:?}): {}", instance_id, path, e);
msg.on_error(format!("Failed to write to UDS stream: {}", e));
break;
}
}
Ok(())
}
/// Parse a UDS endpoint string into a PathBuf
///
/// Accepts formats:
/// - "uds:///path/to/socket"
/// - "/path/to/socket"
fn parse_uds_endpoint(endpoint: &[u8]) -> Result<PathBuf> {
let endpoint_str = std::str::from_utf8(endpoint).context("endpoint is not valid UTF-8")?;
// Strip "uds://" prefix if present
let path_str = endpoint_str.strip_prefix("uds://").unwrap_or(endpoint_str);
if path_str.is_empty() {
anyhow::bail!("empty UDS socket path");
}
Ok(PathBuf::from(path_str))
}
/// Builder for UdsTransport
pub struct UdsTransportBuilder {
socket_path: Option<PathBuf>,
key: Option<TransportKey>,
channel_capacity: usize,
}
impl UdsTransportBuilder {
/// Create a new builder
pub fn new() -> Self {
Self {
socket_path: None,
key: None,
channel_capacity: 256,
}
}
/// Set the socket path
pub fn socket_path(mut self, path: impl Into<PathBuf>) -> Self {
self.socket_path = Some(path.into());
self
}
/// Set the transport key
pub fn key(mut self, key: TransportKey) -> Self {
self.key = Some(key);
self
}
/// Set the channel capacity for backpressure (default: 256)
pub fn channel_capacity(mut self, capacity: usize) -> Self {
self.channel_capacity = capacity;
self
}
/// Build the UdsTransport
pub fn build(self) -> Result<UdsTransport> {
let socket_path = self
.socket_path
.ok_or_else(|| anyhow::anyhow!("socket_path is required"))?;
let key = self.key.unwrap_or_else(|| TransportKey::from("uds"));
let local_endpoint = format!("uds://{}", socket_path.display());
let mut addr_builder = crate::address::WorkerAddressBuilder::new();
addr_builder.add_entry(key.clone(), local_endpoint.as_bytes().to_vec())?;
let local_address = addr_builder.build()?;
Ok(UdsTransport::new(
socket_path,
key,
local_address,
self.channel_capacity,
))
}
}
impl Default for UdsTransportBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::address::WorkerAddressBuilder;
use std::sync::atomic::{AtomicUsize, Ordering};
use velo_common::PeerInfo;
/// Error handler that discards errors (for tests that don't need to track them).
struct NullErrorHandler;
impl TransportErrorHandler for NullErrorHandler {
fn on_error(&self, _: Bytes, _: Bytes, _: String) {}
}
/// Error handler that counts errors (for tests that verify error routing).
struct TrackingErrorHandler {
count: AtomicUsize,
}
impl TrackingErrorHandler {
fn new() -> Self {
Self {
count: AtomicUsize::new(0),
}
}
fn error_count(&self) -> usize {
self.count.load(Ordering::SeqCst)
}
}
impl TransportErrorHandler for TrackingErrorHandler {
fn on_error(&self, _: Bytes, _: Bytes, _: String) {
self.count.fetch_add(1, Ordering::SeqCst);
}
}
/// Build a `PeerInfo` whose UDS endpoint points at `path`.
fn make_uds_peer(path: &Path) -> PeerInfo {
let instance_id = crate::InstanceId::new_v4();
let mut builder = WorkerAddressBuilder::new();
builder
.add_entry("uds", format!("uds://{}", path.display()).into_bytes())
.unwrap();
PeerInfo::new(instance_id, builder.build().unwrap())
}
/// Build a `UdsTransport` with its runtime set, bound to a temp socket path.
/// Returns `(transport, socket_path)`.
fn make_transport() -> (UdsTransport, PathBuf) {
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("test.sock");
let transport = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
transport
.runtime
.set(tokio::runtime::Handle::current())
.ok();
(transport, socket_path)
}
/// Insert a stale `ConnectionHandle` into the transport's connections map.
fn insert_stale_handle(transport: &UdsTransport, instance_id: crate::InstanceId) {
let (tx, _rx) = flume::bounded::<SendTask>(1);
// Drop _rx immediately so tx.is_disconnected() == true
transport
.connections
.insert(instance_id, ConnectionHandle { tx });
}
#[test]
fn test_parse_uds_endpoint() {
// With uds:// prefix
let path = parse_uds_endpoint(b"uds:///tmp/test.sock").unwrap();
assert_eq!(path, PathBuf::from("/tmp/test.sock"));
// Without prefix
let path = parse_uds_endpoint(b"/var/run/anvil.sock").unwrap();
assert_eq!(path, PathBuf::from("/var/run/anvil.sock"));
// Empty path
assert!(parse_uds_endpoint(b"").is_err());
}
#[test]
fn test_builder_requires_socket_path() {
let result = UdsTransportBuilder::new().build();
assert!(result.is_err());
}
#[test]
fn test_builder_with_socket_path() {
let result = UdsTransportBuilder::new()
.socket_path("/tmp/test.sock")
.build();
assert!(result.is_ok());
}
#[test]
fn test_builder_custom_key() {
let transport = UdsTransportBuilder::new()
.socket_path("/tmp/test.sock")
.key(TransportKey::from("custom-uds"))
.build()
.unwrap();
assert_eq!(transport.key(), TransportKey::from("custom-uds"));
}
#[test]
fn test_transport_socket_path() {
let transport = UdsTransportBuilder::new()
.socket_path("/tmp/test.sock")
.build()
.unwrap();
assert_eq!(transport.socket_path(), Path::new("/tmp/test.sock"));
}
#[tokio::test]
async fn test_get_or_create_connection_replaces_stale_handle() {
let (transport, _socket_path) = make_transport();
// Start a UDS listener that the transport can connect to
let dir = std::env::temp_dir().join(format!("uds-peer-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let peer_socket = dir.join("peer.sock");
let peer_listener = tokio::net::UnixListener::bind(&peer_socket).unwrap();
let peer = make_uds_peer(&peer_socket);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert a stale handle
insert_stale_handle(&transport, iid);
assert!(
transport
.connections
.get(&iid)
.unwrap()
.tx
.is_disconnected()
);
// get_or_create_connection should replace the stale handle with a live one
let handle = transport.get_or_create_connection(iid).unwrap();
assert!(!handle.tx.is_disconnected());
// The map entry should also be live
let entry = transport.connections.get(&iid).unwrap();
assert!(!entry.tx.is_disconnected());
// Cleanup
drop(peer_listener);
std::fs::remove_file(&peer_socket).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_check_health_removes_stale_entry() {
let (transport, _socket_path) = make_transport();
// Start a UDS listener so the peer is "reachable"
let dir = std::env::temp_dir().join(format!("uds-peer-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let peer_socket = dir.join("peer.sock");
let _peer_listener = tokio::net::UnixListener::bind(&peer_socket).unwrap();
let peer = make_uds_peer(&peer_socket);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert stale handle — simulates a dead writer task
insert_stale_handle(&transport, iid);
assert!(transport.connections.contains_key(&iid));
// check_health should remove the stale entry and verify the peer is reachable
let result = transport.check_health(iid, Duration::from_secs(2)).await;
// Stale entry should be gone
assert!(!transport.connections.contains_key(&iid));
// Since there WAS a previous connection entry, check_health returns Ok
assert!(result.is_ok());
// Cleanup
std::fs::remove_file(&peer_socket).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_writer_task_cleans_up_on_write_error() {
// Bind a UDS listener, accept once, then drop everything to cause a write error
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("writer-test.sock");
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
let iid = crate::InstanceId::new_v4();
let (tx, rx) = flume::bounded::<SendTask>(8);
let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
Arc::new(DashMap::new());
connections.insert(iid, ConnectionHandle { tx: tx.clone() });
let conns = Arc::clone(&connections);
let cancel = CancellationToken::new();
// Spawn the writer task
let writer = tokio::spawn(connection_writer_task(
socket_path.clone(),
iid,
rx,
conns,
cancel,
));
// Accept the connection, then immediately drop it + the listener
let (stream, _) = listener.accept().await.unwrap();
drop(stream);
drop(listener);
// Send a message — the writer should hit a broken-pipe error
tx.send(SendTask {
msg_type: MessageType::Message,
header: Bytes::from_static(b"hdr"),
payload: Bytes::from_static(b"pay"),
on_error: Arc::new(NullErrorHandler),
})
.unwrap();
// Wait for writer task to finish
let _ = writer.await;
// The writer should have removed the stale entry from the map
assert!(
!connections.contains_key(&iid),
"writer task should clean up its DashMap entry on write error"
);
// Cleanup
std::fs::remove_file(&socket_path).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_send_message_does_not_fail_on_stale_handle() {
let (transport, _socket_path) = make_transport();
// Start a UDS listener that accepts connections (simulates a healthy peer)
let dir = std::env::temp_dir().join(format!("uds-peer-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let peer_socket = dir.join("peer.sock");
let peer_listener = tokio::net::UnixListener::bind(&peer_socket).unwrap();
let peer = make_uds_peer(&peer_socket);
let iid = peer.instance_id();
transport.register(peer).unwrap();
// Insert a stale handle
insert_stale_handle(&transport, iid);
// send_message should detect the stale handle and create a new one
let error_handler = Arc::new(TrackingErrorHandler::new());
transport.send_message(
iid,
b"test-header".to_vec(),
b"test-payload".to_vec(),
MessageType::Message,
error_handler.clone(),
);
// Accept the connection that the new writer task will establish
let (mut stream, _) = peer_listener.accept().await.unwrap();
// Read the framed message from the stream to confirm delivery
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 256];
let n = tokio::time::timeout(Duration::from_secs(2), stream.read(&mut buf))
.await
.expect("timed out waiting for data")
.expect("read error");
assert!(n > 0, "expected data from the writer task");
// No errors should have been reported
assert_eq!(
error_handler.error_count(),
0,
"send_message should retry on stale handle, not fail"
);
// The connections map should now contain a live handle
let entry = transport.connections.get(&iid).unwrap();
assert!(
!entry.tx.is_disconnected(),
"stale handle should have been replaced with a live one"
);
// Cleanup
std::fs::remove_file(&peer_socket).ok();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_double_bind_returns_err() {
use crate::transport::make_channels;
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("double-bind.sock");
let transport1 = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
let instance_id = crate::InstanceId::new_v4();
let (adapter1, _streams1) = make_channels();
let rt = tokio::runtime::Handle::current();
// First bind must succeed.
transport1
.start(instance_id, adapter1, rt.clone())
.await
.unwrap();
// Second transport on the same path must fail.
let transport2 = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
let (adapter2, _streams2) = make_channels();
let result = transport2.start(instance_id, adapter2, rt).await;
assert!(
result.is_err(),
"start() should return Err when a live listener already owns the socket"
);
// Cleanup
transport1.shutdown();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_begin_drain_activates_draining_flag() {
use crate::transport::make_channels;
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let socket_path = dir.join("drain-test.sock");
let transport = UdsTransportBuilder::new()
.socket_path(&socket_path)
.build()
.unwrap();
let instance_id = crate::InstanceId::new_v4();
let (adapter, _streams) = make_channels();
let rt = tokio::runtime::Handle::current();
transport.start(instance_id, adapter, rt).await.unwrap();
assert!(
!transport.shutdown_state.get().unwrap().is_draining(),
"should not be draining before begin_drain()"
);
transport.begin_drain();
assert!(
transport.shutdown_state.get().unwrap().is_draining(),
"should be draining after begin_drain()"
);
// Cleanup
transport.shutdown();
std::fs::remove_dir_all(&dir).ok();
}
#[tokio::test]
async fn test_writer_task_drains_on_connect_failure() {
// Use a socket path where nothing is listening so connect will fail.
let dir = std::env::temp_dir().join(format!("uds-test-{}", crate::InstanceId::new_v4()));
std::fs::create_dir_all(&dir).unwrap();
let dead_socket = dir.join("dead.sock");
let iid = crate::InstanceId::new_v4();
let (tx, rx) = flume::bounded::<SendTask>(8);
let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
Arc::new(DashMap::new());
connections.insert(iid, ConnectionHandle { tx: tx.clone() });
// Queue a message before the writer task starts
let error_handler = Arc::new(TrackingErrorHandler::new());
tx.send(SendTask {
msg_type: MessageType::Message,
header: Bytes::from_static(b"hdr"),
payload: Bytes::from_static(b"pay"),
on_error: error_handler.clone(),
})
.unwrap();
let conns = Arc::clone(&connections);
let cancel = CancellationToken::new();
let writer = tokio::spawn(connection_writer_task(dead_socket, iid, rx, conns, cancel));
let _ = writer.await;
assert_eq!(
error_handler.error_count(),
1,
"queued message should have its on_error called when connect fails"
);
assert!(
!connections.contains_key(&iid),
"writer task should clean up its DashMap entry on connect failure"
);
// Cleanup
std::fs::remove_dir_all(&dir).ok();
}
}
...@@ -24,6 +24,9 @@ use velo_transports::{ ...@@ -24,6 +24,9 @@ use velo_transports::{
tcp::{TcpTransport, TcpTransportBuilder}, tcp::{TcpTransport, TcpTransportBuilder},
}; };
#[cfg(unix)]
use velo_transports::uds::{UdsTransport, UdsTransportBuilder};
use std::sync::Once; use std::sync::Once;
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::FmtSubscriber;
...@@ -252,6 +255,24 @@ impl TestTransportHandle<TcpTransport> { ...@@ -252,6 +255,24 @@ impl TestTransportHandle<TcpTransport> {
} }
} }
// UDS-specific convenience constructors
#[cfg(unix)]
impl TestTransportHandle<UdsTransport> {
/// Create a new UDS transport using a temp directory socket path
pub async fn new_uds() -> anyhow::Result<Self> {
Self::with_factory(|| {
let dir = std::env::temp_dir().join(format!(
"velo-uds-test-{}",
velo_transports::InstanceId::new_v4()
));
std::fs::create_dir_all(&dir)?;
let socket_path = dir.join("transport.sock");
UdsTransportBuilder::new().socket_path(&socket_path).build()
})
.await
}
}
// // UCX-specific convenience constructors // // UCX-specific convenience constructors
// #[cfg(feature = "ucx")] // #[cfg(feature = "ucx")]
// impl TestTransportHandle<UcxTransport> { // impl TestTransportHandle<UcxTransport> {
...@@ -416,6 +437,24 @@ impl TestCluster<TcpTransport> { ...@@ -416,6 +437,24 @@ impl TestCluster<TcpTransport> {
} }
} }
// UDS-specific convenience constructor
#[cfg(unix)]
impl TestCluster<UdsTransport> {
/// Create a new UDS test cluster with the specified number of transports
pub async fn new_uds(size: usize) -> anyhow::Result<Self> {
Self::with_factory(size, || {
let dir = std::env::temp_dir().join(format!(
"velo-uds-test-{}",
velo_transports::InstanceId::new_v4()
));
std::fs::create_dir_all(&dir)?;
let socket_path = dir.join("transport.sock");
UdsTransportBuilder::new().socket_path(&socket_path).build()
})
.await
}
}
// // HTTP-specific convenience constructor // // HTTP-specific convenience constructor
// #[cfg(feature = "http")] // #[cfg(feature = "http")]
// impl TestCluster<HttpTransport> { // impl TestCluster<HttpTransport> {
...@@ -535,6 +574,23 @@ impl TransportFactory for TcpFactory { ...@@ -535,6 +574,23 @@ impl TransportFactory for TcpFactory {
} }
} }
/// UDS transport factory
#[cfg(unix)]
pub struct UdsFactory;
#[cfg(unix)]
impl TransportFactory for UdsFactory {
type Transport = UdsTransport;
async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
TestTransportHandle::new_uds().await
}
async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
TestCluster::new_uds(size).await
}
}
// /// UCX transport factory // /// UCX transport factory
// #[cfg(feature = "ucx")] // #[cfg(feature = "ucx")]
// pub struct UcxFactory; // pub struct UcxFactory;
......
...@@ -359,7 +359,10 @@ async fn test_connection_writer_exits_on_teardown() { ...@@ -359,7 +359,10 @@ async fn test_connection_writer_exits_on_teardown() {
// Give writer tasks time to exit // Give writer tasks time to exit
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
// Sending should now fail (error handler gets invoked, not a panic) // Sending after shutdown: the cancel_token is already cancelled, so any new
// writer task returns immediately without connecting. The error handler must
// be invoked and the message must not arrive on handle_b.
handle_a.error_handler.clear();
handle_a.send( handle_a.send(
handle_b.instance_id, handle_b.instance_id,
b"should-fail".to_vec(), b"should-fail".to_vec(),
...@@ -367,11 +370,25 @@ async fn test_connection_writer_exits_on_teardown() { ...@@ -367,11 +370,25 @@ async fn test_connection_writer_exits_on_teardown() {
MessageType::Message, MessageType::Message,
); );
// Give time for async error path // Give time for the async error path to complete
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// The message either goes to error handler or is silently dropped // Error handler must have been invoked for the failed send
// (connection cleared during shutdown). Just verify no panic occurred. assert!(
handle_a.error_handler.error_count() >= 1,
"error handler should be invoked for post-shutdown send"
);
// The message must not have been delivered to handle_b
let not_delivered = timeout(
Duration::from_millis(100),
handle_b.streams.message_stream.recv_async(),
)
.await;
assert!(
not_delivered.is_err(),
"post-shutdown message must not arrive at handle_b"
);
handle_b.streams.shutdown_state.teardown_token().cancel(); handle_b.streams.shutdown_state.teardown_token().cancel();
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for UDS transport
#![cfg(unix)]
mod common;
use common::{UdsFactory, scenarios};
#[tokio::test]
async fn test_single_message_round_trip() {
scenarios::single_message_round_trip::<UdsFactory>().await;
}
#[tokio::test]
async fn test_bidirectional_messaging() {
scenarios::bidirectional_messaging::<UdsFactory>().await;
}
#[tokio::test]
async fn test_multiple_messages_same_connection() {
scenarios::multiple_messages_same_connection::<UdsFactory>().await;
}
#[tokio::test]
async fn test_response_message_type() {
scenarios::response_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_event_message_type() {
scenarios::event_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_ack_message_type() {
scenarios::ack_message_type::<UdsFactory>().await;
}
#[tokio::test]
async fn test_mixed_message_types() {
scenarios::mixed_message_types::<UdsFactory>().await;
}
#[tokio::test]
async fn test_large_payload() {
scenarios::large_payload::<UdsFactory>().await;
}
#[tokio::test]
async fn test_empty_header_and_payload() {
scenarios::empty_header_and_payload::<UdsFactory>().await;
}
#[tokio::test]
async fn test_cluster_mesh_communication() {
scenarios::cluster_mesh_communication::<UdsFactory>().await;
}
#[tokio::test]
async fn test_concurrent_senders() {
scenarios::concurrent_senders::<UdsFactory>().await;
}
#[tokio::test]
async fn test_send_to_unregistered_peer() {
scenarios::send_to_unregistered_peer::<UdsFactory>().await;
}
#[tokio::test]
async fn test_connection_reuse() {
scenarios::connection_reuse::<UdsFactory>().await;
}
#[tokio::test]
async fn test_graceful_shutdown() {
scenarios::graceful_shutdown::<UdsFactory>().await;
}
#[tokio::test]
async fn test_high_throughput() {
scenarios::high_throughput::<UdsFactory>().await;
}
#[tokio::test]
async fn test_zero_copy_efficiency() {
scenarios::zero_copy_efficiency::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_rejects_messages() {
scenarios::drain_rejects_messages::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_responses() {
scenarios::drain_accepts_responses::<UdsFactory>().await;
}
#[tokio::test]
async fn test_drain_accepts_events() {
scenarios::drain_accepts_events::<UdsFactory>().await;
}
#[tokio::test]
async fn test_health_during_drain() {
scenarios::health_during_drain::<UdsFactory>().await;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for UDS graceful shutdown
//!
//! These tests verify the 3-phase shutdown behavior:
//! 1. Gate: new Message frames are rejected with ShuttingDown
//! 2. Drain: in-flight work completes, responses/events still flow
//! 3. Teardown: listener and writer tasks exit
#![cfg(unix)]
mod common;
use bytes::Bytes;
use std::time::Duration;
use tokio::time::{sleep, timeout};
use velo_transports::tcp::TcpFrameCodec;
use velo_transports::uds::UdsTransport;
use velo_transports::{MessageType, Transport};
use common::TestTransportHandle;
/// Get the socket path from a UDS transport by parsing its WorkerAddress.
fn get_socket_path(handle: &TestTransportHandle<UdsTransport>) -> std::path::PathBuf {
let addr = handle.transport.address();
let key = handle.transport.key();
let endpoint = addr.get_entry(&key).unwrap().unwrap();
let s = std::str::from_utf8(&endpoint).unwrap();
let s = s.strip_prefix("uds://").unwrap_or(s);
std::path::PathBuf::from(s)
}
/// Helper: connect a raw UDS client to the transport's socket and send a frame.
async fn connect_and_send_frame(
socket_path: &std::path::Path,
msg_type: MessageType,
header: &[u8],
payload: &[u8],
) -> tokio::net::UnixStream {
let mut stream = tokio::net::UnixStream::connect(socket_path).await.unwrap();
TcpFrameCodec::encode_frame(&mut stream, msg_type, header, payload)
.await
.unwrap();
stream
}
/// Helper: read one frame from a raw UDS stream.
async fn read_one_frame(stream: &mut tokio::net::UnixStream) -> (MessageType, Bytes, Bytes) {
use futures::StreamExt;
use tokio_util::codec::Framed;
let mut framed = Framed::new(stream, TcpFrameCodec::new());
framed.next().await.unwrap().unwrap()
}
// --- Test: Drain rejects Message frames ---
#[tokio::test]
async fn test_uds_drain_rejects_messages() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Begin drain
handle.streams.shutdown_state.begin_drain();
// Give listener time to be ready
sleep(Duration::from_millis(50)).await;
// Connect and send a Message frame
let mut stream = connect_and_send_frame(
&socket_path,
MessageType::Message,
b"req-header",
b"req-payload",
)
.await;
// Should get ShuttingDown back
let (msg_type, header, payload) = read_one_frame(&mut stream).await;
assert_eq!(msg_type, MessageType::ShuttingDown);
assert_eq!(&header[..], b"req-header"); // Original header echoed back
assert_eq!(payload.len(), 0); // Empty payload
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Drain accepts Response frames ---
#[tokio::test]
async fn test_uds_drain_accepts_responses() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Begin drain
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Connect and send a Response frame
connect_and_send_frame(
&socket_path,
MessageType::Response,
b"resp-header",
b"resp-payload",
)
.await;
// Should arrive on the response stream
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"resp-header");
assert_eq!(&payload[..], b"resp-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Drain accepts Event frames ---
#[tokio::test]
async fn test_uds_drain_accepts_events() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
connect_and_send_frame(
&socket_path,
MessageType::Event,
b"evt-header",
b"evt-payload",
)
.await;
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.event_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"evt-header");
assert_eq!(&payload[..], b"evt-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: New connection during drain still accepts responses ---
#[tokio::test]
async fn test_uds_new_connection_during_drain() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Begin drain BEFORE connecting
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Establish a NEW connection after drain starts
connect_and_send_frame(
&socket_path,
MessageType::Response,
b"new-resp",
b"new-payload",
)
.await;
// Should arrive on the response stream
let (header, payload) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"new-resp");
assert_eq!(&payload[..], b"new-payload");
handle.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Full graceful shutdown lifecycle ---
#[tokio::test]
async fn test_uds_graceful_shutdown_lifecycle() {
let handle = TestTransportHandle::new_uds().await.unwrap();
let socket_path = get_socket_path(&handle);
// Verify normal operation: send a message, receive it
connect_and_send_frame(
&socket_path,
MessageType::Message,
b"normal-msg",
b"normal-pay",
)
.await;
let (header, _payload) = timeout(
Duration::from_secs(2),
handle.streams.message_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"normal-msg");
// Acquire an InFlightGuard (simulate in-progress request)
let guard = handle.streams.shutdown_state.acquire();
assert_eq!(handle.streams.shutdown_state.in_flight_count(), 1);
// Begin drain (Phase 1)
handle.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Verify new messages are rejected
let mut stream =
connect_and_send_frame(&socket_path, MessageType::Message, b"reject-me", b"").await;
let (msg_type, _, _) = read_one_frame(&mut stream).await;
assert_eq!(msg_type, MessageType::ShuttingDown);
// Verify responses still flow
connect_and_send_frame(&socket_path, MessageType::Response, b"still-ok", b"data").await;
let (header, _) = timeout(
Duration::from_secs(2),
handle.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"still-ok");
// Spawn graceful_shutdown in background (will block on drain since guard is held)
let shutdown_state = handle.streams.shutdown_state.clone();
let shutdown_handle = tokio::spawn(async move {
// Phase 2: wait for drain
shutdown_state.wait_for_drain().await;
// Phase 3: teardown
shutdown_state.teardown_token().cancel();
});
// Verify shutdown hasn't completed yet (guard still held)
sleep(Duration::from_millis(100)).await;
assert!(!shutdown_handle.is_finished());
// Drop guard -> drain completes -> teardown fires
drop(guard);
timeout(Duration::from_secs(2), shutdown_handle)
.await
.expect("shutdown should complete")
.unwrap();
assert!(
handle
.streams
.shutdown_state
.teardown_token()
.is_cancelled()
);
}
// --- Test: Shutdown timeout forces teardown ---
#[tokio::test]
async fn test_uds_shutdown_timeout_forces_teardown() {
let handle = TestTransportHandle::new_uds().await.unwrap();
// Acquire guard and hold it
let _guard = handle.streams.shutdown_state.acquire();
let shutdown_state = handle.streams.shutdown_state.clone();
let shutdown_handle = tokio::spawn(async move {
shutdown_state.begin_drain();
// Phase 2: wait with short timeout
let _ =
tokio::time::timeout(Duration::from_millis(100), shutdown_state.wait_for_drain()).await;
// Phase 3: teardown (forced, guard still held)
shutdown_state.teardown_token().cancel();
});
timeout(Duration::from_secs(2), shutdown_handle)
.await
.expect("shutdown should complete via timeout")
.unwrap();
// Teardown should have fired even though guard is held
assert!(
handle
.streams
.shutdown_state
.teardown_token()
.is_cancelled()
);
// Guard is still held (not a problem — teardown was forced)
assert_eq!(handle.streams.shutdown_state.in_flight_count(), 1);
}
// --- Test: Outbound sends during drain ---
#[tokio::test]
async fn test_uds_outbound_sends_during_drain() {
// Create two transports and register them as peers
let handle_a = TestTransportHandle::new_uds().await.unwrap();
let handle_b = TestTransportHandle::new_uds().await.unwrap();
handle_a.register_peer(&handle_b).unwrap();
handle_b.register_peer(&handle_a).unwrap();
// Begin drain on transport A
handle_a.streams.shutdown_state.begin_drain();
sleep(Duration::from_millis(50)).await;
// Send a Response from A to B (outbound sends should work during drain)
handle_a.send(
handle_b.instance_id,
b"response-hdr".to_vec(),
b"response-pay".to_vec(),
MessageType::Response,
);
// B should receive the response
let (header, payload) = timeout(
Duration::from_secs(2),
handle_b.streams.response_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
assert_eq!(&header[..], b"response-hdr");
assert_eq!(&payload[..], b"response-pay");
handle_a.streams.shutdown_state.teardown_token().cancel();
handle_b.streams.shutdown_state.teardown_token().cancel();
}
// --- Test: Connection writer exits on teardown ---
#[tokio::test]
async fn test_uds_connection_writer_exits_on_teardown() {
let handle_a = TestTransportHandle::new_uds().await.unwrap();
let handle_b = TestTransportHandle::new_uds().await.unwrap();
handle_a.register_peer(&handle_b).unwrap();
// Send a message to establish the connection writer task
handle_a.send(
handle_b.instance_id,
b"setup".to_vec(),
b"data".to_vec(),
MessageType::Message,
);
// Wait for it to arrive
timeout(
Duration::from_secs(2),
handle_b.streams.message_stream.recv_async(),
)
.await
.expect("timeout")
.expect("recv");
// Shutdown transport A
handle_a.transport.shutdown();
// Give writer tasks time to exit
sleep(Duration::from_millis(200)).await;
// Sending after shutdown: the cancel_token is already cancelled, so any new
// writer task returns immediately without connecting. The error handler must
// be invoked and the message must not arrive on handle_b.
handle_a.error_handler.clear();
handle_a.send(
handle_b.instance_id,
b"should-fail".to_vec(),
b"data".to_vec(),
MessageType::Message,
);
// Give time for the async error path to complete
sleep(Duration::from_millis(100)).await;
// Error handler must have been invoked for the failed send
assert!(
handle_a.error_handler.error_count() >= 1,
"error handler should be invoked for post-shutdown send"
);
// The message must not have been delivered to handle_b
let not_delivered = timeout(
Duration::from_millis(100),
handle_b.streams.message_stream.recv_async(),
)
.await;
assert!(
not_delivered.is_err(),
"post-shutdown message must not arrive at handle_b"
);
handle_b.streams.shutdown_state.teardown_token().cancel();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment