Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
63d7c01c
Unverified
Commit
63d7c01c
authored
Mar 09, 2026
by
Ryan Olson
Committed by
GitHub
Mar 09, 2026
Browse files
feat: velo-backend (#6547)
Signed-off-by:
Ryan Olson
<
rolson@nvidia.com
>
parent
2cc92bfa
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
3773 additions
and
0 deletions
+3773
-0
lib/velo-transports/src/tcp/listener.rs
lib/velo-transports/src/tcp/listener.rs
+496
-0
lib/velo-transports/src/tcp/mod.rs
lib/velo-transports/src/tcp/mod.rs
+19
-0
lib/velo-transports/src/tcp/transport.rs
lib/velo-transports/src/tcp/transport.rs
+958
-0
lib/velo-transports/src/transport.rs
lib/velo-transports/src/transport.rs
+578
-0
lib/velo-transports/src/utils/mod.rs
lib/velo-transports/src/utils/mod.rs
+0
-0
lib/velo-transports/tests/common/mod.rs
lib/velo-transports/tests/common/mod.rs
+618
-0
lib/velo-transports/tests/common/scenarios.rs
lib/velo-transports/tests/common/scenarios.rs
+619
-0
lib/velo-transports/tests/tcp_integration.rs
lib/velo-transports/tests/tcp_integration.rs
+108
-0
lib/velo-transports/tests/tcp_shutdown.rs
lib/velo-transports/tests/tcp_shutdown.rs
+377
-0
No files found.
lib/velo-transports/src/tcp/listener.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP listener for ActiveMessage transport
//!
//! This module provides a TCP server that accepts incoming connections,
//! decodes framed messages using zero-copy techniques, and routes them
//! to the appropriate transport streams.
use
anyhow
::{
Context
,
Result
};
use
bytes
::
Bytes
;
use
futures
::
StreamExt
;
use
std
::
net
::
SocketAddr
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
net
::{
TcpListener
as
TokioTcpListener
,
TcpStream
};
use
tokio
::
runtime
::{
Handle
,
Runtime
};
use
tokio_util
::
codec
::
Framed
;
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
crate
::{
MessageType
,
ShutdownState
,
TransportAdapter
,
TransportErrorHandler
};
use
super
::
framing
::
TcpFrameCodec
;
/// Selects the tokio runtime on which the TCP listener runs.
pub enum RuntimeConfig {
    /// Spawn the listener through an existing runtime handle.
    Handle(Handle),
    /// Spawn the listener on a caller-supplied, shared runtime.
    Runtime(Arc<Runtime>),
    /// Build a dedicated single-threaded runtime for the given CPU core index
    /// (core affinity is only requested on Linux).
    CpuPin(usize),
}
/// TCP server side of the ActiveMessage transport.
///
/// Accepts inbound connections, decodes frames from each one, and forwards the
/// decoded header/payload buffers to the streams exposed by the adapter.
pub struct TcpListener {
    /// Address to bind when no pre-bound socket was supplied.
    bind_addr: SocketAddr,
    /// Destination streams for decoded frames.
    adapter: TransportAdapter,
    /// Invoked with the frame data when a frame cannot be routed.
    error_handler: Arc<dyn TransportErrorHandler>,
    /// Drain / teardown coordination shared with the owning transport.
    shutdown_state: ShutdownState,
    /// Where the accept loop is executed.
    runtime_config: RuntimeConfig,
    /// Optional already-bound socket; used instead of binding `bind_addr`.
    listener: Option<std::net::TcpListener>,
}
impl
TcpListener
{
/// Returns an empty [`TcpListenerBuilder`] for configuring a listener.
pub fn builder() -> TcpListenerBuilder {
    TcpListenerBuilder::default()
}
/// Start the listener and serve incoming connections
///
/// This method blocks or spawns based on the runtime configuration:
/// - For Handle/Runtime: spawns tasks and returns immediately
/// - For CpuPin: creates a pinned runtime and blocks until cancellation
pub
async
fn
serve
(
mut
self
)
->
Result
<
()
>
{
// Extract runtime config to avoid borrow issues
let
runtime_config
=
std
::
mem
::
replace
(
&
mut
self
.runtime_config
,
RuntimeConfig
::
Handle
(
Handle
::
current
()),
);
match
runtime_config
{
RuntimeConfig
::
Handle
(
handle
)
=>
{
handle
.spawn
(
async
move
{
if
let
Err
(
e
)
=
self
.run_server
()
.await
{
error!
(
"TCP listener error: {}"
,
e
);
}
});
Ok
(())
}
RuntimeConfig
::
Runtime
(
rt
)
=>
{
rt
.spawn
(
async
move
{
if
let
Err
(
e
)
=
self
.run_server
()
.await
{
error!
(
"TCP listener error: {}"
,
e
);
}
});
Ok
(())
}
RuntimeConfig
::
CpuPin
(
cpu_id
)
=>
{
let
rt
=
Self
::
create_pinned_runtime
(
cpu_id
)
.context
(
"Failed to create CPU-pinned runtime"
)
?
;
rt
.block_on
(
self
.run_server
())
}
}
}
/// Create a single-threaded runtime pinned to a specific CPU core
#[cfg(target_os
=
"linux"
)]
fn
create_pinned_runtime
(
cpu_id
:
usize
)
->
Result
<
Runtime
>
{
use
nix
::
sched
::{
CpuSet
,
sched_setaffinity
};
use
nix
::
unistd
::
Pid
;
tokio
::
runtime
::
Builder
::
new_current_thread
()
.enable_all
()
.thread_name
(
"tcp-listener-pinned"
)
.on_thread_start
(
move
||
{
let
mut
cpu_set
=
CpuSet
::
new
();
if
cpu_set
.set
(
cpu_id
)
.is_ok
()
{
if
let
Err
(
e
)
=
sched_setaffinity
(
Pid
::
from_raw
(
0
),
&
cpu_set
)
{
error!
(
"Failed to pin thread to CPU {}: {}"
,
cpu_id
,
e
);
}
else
{
debug!
(
"Successfully pinned TCP listener to CPU {}"
,
cpu_id
);
}
}
})
.build
()
.context
(
"Failed to build tokio runtime"
)
}
/// Builds a plain current-thread tokio runtime on platforms other than Linux.
///
/// The requested core is only mentioned in a warning; no affinity is applied.
#[cfg(not(target_os = "linux"))]
fn create_pinned_runtime(cpu_id: usize) -> Result<Runtime> {
    warn!(
        "CPU pinning requested (CPU {}) but not supported on this platform",
        cpu_id
    );

    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all().thread_name("tcp-listener");
    builder.build().context("Failed to build tokio runtime")
}
/// Main server loop that accepts connections
async
fn
run_server
(
self
)
->
Result
<
()
>
{
// Use pre-bound listener if provided, otherwise bind to the address
let
listener
=
if
let
Some
(
std_listener
)
=
self
.listener
{
// Set non-blocking for tokio conversion
std_listener
.set_nonblocking
(
true
)
.context
(
"Failed to set listener to non-blocking"
)
?
;
TokioTcpListener
::
from_std
(
std_listener
)
.context
(
"Failed to convert std TcpListener to tokio TcpListener"
)
?
}
else
{
TokioTcpListener
::
bind
(
self
.bind_addr
)
.await
.context
(
format!
(
"Failed to bind TCP listener to {}"
,
self
.bind_addr
))
?
};
let
local_addr
=
listener
.local_addr
()
.context
(
"Failed to get local address"
)
?
;
info!
(
"TCP listener bound to {}"
,
local_addr
);
let
teardown_token
=
self
.shutdown_state
.teardown_token
()
.clone
();
loop
{
tokio
::
select!
{
accept_result
=
listener
.accept
()
=>
{
match
accept_result
{
Ok
((
stream
,
peer_addr
))
=>
{
debug!
(
"Accepted TCP connection from {}"
,
peer_addr
);
let
adapter
=
self
.adapter
.clone
();
let
error_handler
=
self
.error_handler
.clone
();
let
shutdown_state
=
self
.shutdown_state
.clone
();
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
Self
::
handle_connection
(
stream
,
peer_addr
,
adapter
,
error_handler
,
shutdown_state
,
)
.await
{
warn!
(
"Error handling connection from {}: {}"
,
peer_addr
,
e
);
}
});
}
Err
(
e
)
=>
{
error!
(
"Failed to accept TCP connection: {}"
,
e
);
}
}
}
_
=
teardown_token
.cancelled
()
=>
{
info!
(
"TCP listener shutting down (teardown)"
);
break
;
}
}
}
Ok
(())
}
/// Handle a single TCP connection
async
fn
handle_connection
(
stream
:
TcpStream
,
peer_addr
:
SocketAddr
,
adapter
:
TransportAdapter
,
error_handler
:
Arc
<
dyn
TransportErrorHandler
>
,
shutdown_state
:
ShutdownState
,
)
->
Result
<
()
>
{
debug!
(
"Configuring connection from {}"
,
peer_addr
);
// Configure socket for high performance
if
let
Err
(
e
)
=
stream
.set_nodelay
(
true
)
{
warn!
(
"Failed to set TCP_NODELAY on {}: {}"
,
peer_addr
,
e
);
}
#[allow(deprecated)]
// Intentional: linger ensures clean socket shutdown
if
let
Err
(
e
)
=
stream
.set_linger
(
Some
(
Duration
::
from_secs
(
1
)))
{
warn!
(
"Failed to set linger on {}: {}"
,
peer_addr
,
e
);
}
// Set keep-alive to detect dead connections
let
keepalive
=
socket2
::
TcpKeepalive
::
new
()
.with_time
(
Duration
::
from_secs
(
60
))
.with_interval
(
Duration
::
from_secs
(
10
));
let
sock_ref
=
socket2
::
SockRef
::
from
(
&
stream
);
if
let
Err
(
e
)
=
sock_ref
.set_tcp_keepalive
(
&
keepalive
)
{
warn!
(
"Failed to set TCP keepalive on {}: {}"
,
peer_addr
,
e
);
}
// Set large receive buffer for high throughput
if
let
Err
(
e
)
=
sock_ref
.set_recv_buffer_size
(
1_048_576
)
{
warn!
(
"Failed to set receive buffer size on {}: {}"
,
peer_addr
,
e
);
}
// Create framed stream with zero-copy codec
let
mut
framed
=
Framed
::
new
(
stream
,
TcpFrameCodec
::
new
());
let
teardown_token
=
shutdown_state
.teardown_token
()
.clone
();
debug!
(
"Connection from {} ready for frames"
,
peer_addr
);
loop
{
tokio
::
select!
{
frame_result
=
framed
.next
()
=>
{
match
frame_result
{
Some
(
Ok
((
msg_type
,
header
,
payload
)))
=>
{
// During drain: reject new Message frames with ShuttingDown,
// but always pass through Response/Ack/Event frames.
if
shutdown_state
.is_draining
()
&&
msg_type
==
MessageType
::
Message
{
debug!
(
"Rejecting Message frame from {} during drain (sending ShuttingDown)"
,
peer_addr
);
// Echo original header back for correlation, empty payload
if
let
Err
(
e
)
=
TcpFrameCodec
::
encode_frame
(
framed
.get_mut
(),
MessageType
::
ShuttingDown
,
&
header
,
&
[],
)
.await
{
warn!
(
"Failed to send ShuttingDown frame to {}: {}"
,
peer_addr
,
e
);
}
continue
;
}
// Route frame to appropriate stream based on type
if
let
Err
(
e
)
=
Self
::
route_frame
(
msg_type
,
header
,
payload
,
&
adapter
,
&
error_handler
,
)
.await
{
warn!
(
"Failed to route {:?} frame from {}: {}"
,
msg_type
,
peer_addr
,
e
);
}
}
Some
(
Err
(
e
))
=>
{
error!
(
"Frame decode error from {}: {}"
,
peer_addr
,
e
);
break
;
}
None
=>
{
// Connection closed gracefully (FIN received)
debug!
(
"Connection from {} closed gracefully"
,
peer_addr
);
break
;
}
}
}
_
=
teardown_token
.cancelled
()
=>
{
debug!
(
"Connection handler for {} torn down"
,
peer_addr
);
break
;
}
}
}
Ok
(())
}
/// Forwards one decoded frame to the adapter stream selected by its type.
///
/// The header and payload buffers are moved into the channel. If the channel
/// refuses the send, the buffers are recovered from the send error and handed to
/// the error handler, and an error is returned.
async fn route_frame(
    msg_type: MessageType,
    header: Bytes,
    payload: Bytes,
    adapter: &TransportAdapter,
    error_handler: &Arc<dyn TransportErrorHandler>,
) -> Result<()> {
    let target = match msg_type {
        MessageType::Message => &adapter.message_stream,
        MessageType::Response => &adapter.response_stream,
        MessageType::Ack | MessageType::Event => &adapter.event_stream,
        // A ShuttingDown frame arriving here means a remote peer rejected one of
        // our requests; it travels on the response stream so higher layers can
        // match it to the request via correlation.
        MessageType::ShuttingDown => &adapter.response_stream,
    };

    let Err(send_err) = target.send_async((header, payload)).await else {
        return Ok(());
    };

    // The failed send gives the frame back; pass it on to the error callback.
    let (header, payload) = send_err.0;
    error_handler.on_error(header, payload, format!("Failed to route {:?}", msg_type));
    Err(anyhow::anyhow!("Failed to send to stream"))
}
}
/// Step-by-step configuration for a [`TcpListener`].
///
/// Every field starts unset; `build()` decides which ones are mandatory.
pub struct TcpListenerBuilder {
    /// Address to bind (required).
    bind_addr: Option<SocketAddr>,
    /// Streams receiving decoded frames (required).
    adapter: Option<TransportAdapter>,
    /// Callback for frames that cannot be routed (required).
    error_handler: Option<Arc<dyn TransportErrorHandler>>,
    /// Drain / teardown state; a default one is used when unset.
    shutdown_state: Option<ShutdownState>,
    /// Runtime selection; the current runtime handle is used when unset.
    runtime_config: Option<RuntimeConfig>,
    /// Optional pre-bound socket.
    listener: Option<std::net::TcpListener>,
}
impl TcpListenerBuilder {
    /// Create a new builder with every option unset.
    pub fn new() -> Self {
        Self {
            bind_addr: None,
            adapter: None,
            error_handler: None,
            shutdown_state: None,
            runtime_config: None,
            listener: None,
        }
    }

    /// Set the bind address (required).
    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
        self.bind_addr = Some(addr);
        self
    }

    /// Set the transport adapter that receives decoded frames (required).
    pub fn adapter(mut self, adapter: TransportAdapter) -> Self {
        self.adapter = Some(adapter);
        self
    }

    /// Set the error handler invoked when a frame cannot be routed (required).
    pub fn error_handler(mut self, handler: Arc<dyn TransportErrorHandler>) -> Self {
        self.error_handler = Some(handler);
        self
    }

    /// Set the shutdown state for graceful drain coordination.
    pub fn shutdown_state(mut self, state: ShutdownState) -> Self {
        self.shutdown_state = Some(state);
        self
    }

    /// Use an existing tokio runtime handle.
    pub fn with_handle(mut self, handle: Handle) -> Self {
        self.runtime_config = Some(RuntimeConfig::Handle(handle));
        self
    }

    /// Use a provided tokio runtime.
    pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
        self.runtime_config = Some(RuntimeConfig::Runtime(runtime));
        self
    }

    /// Create a single-threaded runtime pinned to a specific CPU core.
    pub fn with_cpu_pin(mut self, cpu_id: usize) -> Self {
        self.runtime_config = Some(RuntimeConfig::CpuPin(cpu_id));
        self
    }

    /// Use a pre-bound TcpListener.
    ///
    /// This is useful for tests where you want to bind to port 0 and avoid port races.
    /// When provided, the bind_addr should still be set (for logging/debugging purposes).
    /// Passing `None` clears any previously supplied listener.
    pub fn listener(mut self, listener: Option<std::net::TcpListener>) -> Self {
        self.listener = listener;
        self
    }

    /// Build the TcpListener.
    ///
    /// # Errors
    /// Returns an error when `bind_addr`, `adapter`, or `error_handler` is missing,
    /// or when no runtime was configured and there is no current tokio runtime to
    /// fall back to.
    pub fn build(self) -> Result<TcpListener> {
        let bind_addr = self
            .bind_addr
            .ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;
        let adapter = self
            .adapter
            .ok_or_else(|| anyhow::anyhow!("adapter is required"))?;
        let error_handler = self
            .error_handler
            .ok_or_else(|| anyhow::anyhow!("error_handler is required"))?;
        let shutdown_state = self.shutdown_state.unwrap_or_default();

        // Fall back to the ambient runtime, but report its absence as an error:
        // `Handle::current()` would panic here, which a `Result`-returning
        // builder should not do.
        let runtime_config = match self.runtime_config {
            Some(config) => config,
            None => RuntimeConfig::Handle(Handle::try_current().context(
                "no runtime configured and build() was called outside a tokio runtime",
            )?),
        };

        Ok(TcpListener {
            bind_addr,
            adapter,
            error_handler,
            shutdown_state,
            runtime_config,
            listener: self.listener,
        })
    }
}
impl
Default
for
TcpListenerBuilder
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::make_channels;
    use std::net::{IpAddr, Ipv4Addr};

    /// Error handler that only prints the reported error.
    struct TestErrorHandler;

    impl TransportErrorHandler for TestErrorHandler {
        fn on_error(&self, _header: Bytes, _payload: Bytes, error: String) {
            eprintln!("Test error handler: {}", error);
        }
    }

    /// Loopback address with port 0.
    fn loopback_any_port() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)
    }

    /// An unconfigured builder must be rejected.
    #[test]
    fn test_builder_requires_fields() {
        let outcome = TcpListener::builder().build();
        assert!(outcome.is_err());
    }

    /// With the three required fields set, the runtime defaults to the current
    /// handle, hence the tokio test harness.
    #[tokio::test]
    async fn test_builder_with_all_fields() {
        let (adapter, _streams) = make_channels();
        let error_handler = Arc::new(TestErrorHandler);

        let outcome = TcpListener::builder()
            .bind_addr(loopback_any_port())
            .adapter(adapter)
            .error_handler(error_handler)
            .build();

        assert!(outcome.is_ok());
    }

    /// CPU pinning is an explicit runtime choice, so no ambient runtime is needed.
    #[test]
    fn test_builder_with_cpu_pin() {
        let (adapter, _streams) = make_channels();
        let error_handler = Arc::new(TestErrorHandler);

        let outcome = TcpListener::builder()
            .bind_addr(loopback_any_port())
            .adapter(adapter)
            .error_handler(error_handler)
            .with_cpu_pin(0)
            .build();

        assert!(outcome.is_ok());
    }
}
lib/velo-transports/src/tcp/mod.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! TCP Transport Module
//!
//! This module provides a high-performance TCP transport implementation with:
//! - Zero-copy frame codec for minimal overhead
//! - CPU pinning support for predictable latency
//! - Frame type routing (Message, Response, Ack, Event, ShuttingDown)
//! - Graceful shutdown with proper FIN handling
//! - Keep-alive for dead connection detection
mod
framing
;
mod
listener
;
mod
transport
;
pub
use
framing
::
TcpFrameCodec
;
pub
use
listener
::{
RuntimeConfig
,
TcpListener
,
TcpListenerBuilder
};
pub
use
transport
::{
TcpTransport
,
TcpTransportBuilder
};
lib/velo-transports/src/tcp/transport.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! High-performance TCP transport
//!
//! Peer and connection state live in `Arc<DashMap<..>>` maps, and each outbound connection is
//! served by a writer task spawned on the tokio runtime handle given to `start()`. Nothing here
//! uses `Rc`, `RefCell`, or `LocalSet`; any single-threaded placement comes from the runtime
//! the caller supplies.
use
anyhow
::{
Context
,
Result
};
use
bytes
::
Bytes
;
use
dashmap
::
DashMap
;
use
std
::
net
::{
SocketAddr
,
ToSocketAddrs
};
use
std
::
sync
::{
Arc
,
Mutex
,
OnceLock
};
use
std
::
time
::
Duration
;
use
tokio
::
net
::
TcpStream
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
crate
::
transport
::{
HealthCheckError
,
ShutdownState
,
TransportError
,
TransportErrorHandler
};
use
crate
::{
MessageType
,
PeerInfo
,
Transport
,
TransportAdapter
,
TransportKey
,
WorkerAddress
};
use
super
::
framing
::
TcpFrameCodec
;
use
super
::
listener
::
TcpListener
;
/// TCP implementation of the `Transport` trait.
///
/// Peer addresses and live connections are kept in shared `DashMap`s so they can
/// be reached from `&self` methods and from the spawned writer tasks. Tasks are
/// spawned through the runtime handle captured in `start()`.
pub struct TcpTransport {
    // --- identity (fixed at construction) ---
    key: TransportKey,
    bind_addr: SocketAddr,
    local_address: WorkerAddress,

    // --- shared mutable state ---
    /// Registered peers and the socket address each one resolved to.
    peers: Arc<DashMap<crate::InstanceId, SocketAddr>>,
    /// Sender handles of the per-peer writer tasks.
    connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>>,

    /// Runtime handle for spawning tasks; set once by `start()`.
    runtime: OnceLock<tokio::runtime::Handle>,

    // --- shutdown coordination ---
    /// Cancelled by `shutdown()`; a clone is handed to every writer task.
    cancel_token: CancellationToken,
    /// Shutdown state taken from the adapter in `start()`.
    shutdown_state: OnceLock<ShutdownState>,

    /// Capacity of each per-connection send channel (backpressure bound).
    channel_capacity: usize,

    /// Optional pre-bound listener (used by tests to avoid port races); taken
    /// out by `start()`.
    listener: Mutex<Option<std::net::TcpListener>>,
}
/// Cloneable sending side of one connection's writer task.
#[derive(Clone)]
struct ConnectionHandle {
    /// Queue feeding the writer task; reports disconnected once the task's
    /// receiver has been dropped.
    tx: flume::Sender<SendTask>,
}
/// One outbound frame queued for a writer task, together with the callback to
/// notify if it cannot be delivered.
struct SendTask {
    /// Frame type written alongside the data.
    msg_type: MessageType,
    /// Frame header bytes.
    header: Bytes,
    /// Frame payload bytes.
    payload: Bytes,
    /// Receives the header and payload back, plus a description, on failure.
    on_error: Arc<dyn TransportErrorHandler>,
}
impl
SendTask
{
fn
on_error
(
self
,
error
:
impl
Into
<
String
>
)
{
self
.on_error
.on_error
(
self
.header
,
self
.payload
,
error
.into
());
}
}
impl
TcpTransport
{
/// Builds a transport that has not been started yet.
///
/// * `bind_addr` - address the listener binds in `start()`.
/// * `key` - transport key used to look up peer endpoints.
/// * `local_address` - address returned by `address()`.
/// * `channel_capacity` - bound of every per-connection send queue
///   (the builder defaults this to 256).
/// * `listener` - optional pre-bound socket, handy for tests binding port 0.
pub fn new(
    bind_addr: SocketAddr,
    key: TransportKey,
    local_address: WorkerAddress,
    channel_capacity: usize,
    listener: Option<std::net::TcpListener>,
) -> Self {
    // No peers and no connections yet.
    let peers = Arc::new(DashMap::new());
    let connections = Arc::new(DashMap::new());

    Self {
        key,
        bind_addr,
        local_address,
        peers,
        connections,
        // Both cells are filled in by `start()`.
        runtime: OnceLock::new(),
        cancel_token: CancellationToken::new(),
        shutdown_state: OnceLock::new(),
        channel_capacity,
        listener: Mutex::new(listener),
    }
}
/// Eagerly sets up the connection to an already registered peer.
///
/// Optional: without it the connection is created on the first
/// `send_message()` to that peer. Fails under the same conditions as the lazy
/// path (transport not started, peer not registered).
pub fn ensure_connected(&self, instance_id: crate::InstanceId) -> Result<()> {
    self.get_or_create_connection(instance_id)
        .map(|_handle| ())
}
/// Returns the live connection handle for `instance_id`, creating one when no
/// usable entry exists.
///
/// An entry whose channel is disconnected is treated as stale and replaced.
/// Fails when the transport has not been started or the peer is not registered.
fn get_or_create_connection(&self, instance_id: crate::InstanceId) -> Result<ConnectionHandle> {
    use dashmap::mapref::entry::Entry;

    // Fast path: a live entry is already present.
    if let Some(existing) = self.connections.get(&instance_id) {
        if !existing.tx.is_disconnected() {
            return Ok(existing.clone());
        }
        // Stale: release the read guard before touching the map again.
        drop(existing);
        self.connections
            .remove_if(&instance_id, |_, stale| stale.tx.is_disconnected());
    }

    let rt = self.runtime.get().ok_or(TransportError::NotStarted)?;

    // Check-and-insert under the entry guard so concurrent callers agree on a
    // single handle.
    match self.connections.entry(instance_id) {
        Entry::Occupied(mut slot) => {
            if slot.get().tx.is_disconnected() {
                // Stale entry: swap a fresh connection into the same slot.
                let fresh = self.create_connection(instance_id, rt)?;
                slot.insert(fresh.clone());
                Ok(fresh)
            } else {
                Ok(slot.get().clone())
            }
        }
        Entry::Vacant(slot) => {
            let fresh = self.create_connection(instance_id, rt)?;
            slot.insert(fresh.clone());
            Ok(fresh)
        }
    }
}
/// Creates the send queue for a peer and spawns its writer task on `rt`.
///
/// The returned handle is not stored in the connection map here; that is left
/// to the caller. Fails with `PeerNotRegistered` when the peer has no address.
fn create_connection(
    &self,
    instance_id: crate::InstanceId,
    rt: &tokio::runtime::Handle,
) -> Result<ConnectionHandle> {
    let addr = self
        .peers
        .get(&instance_id)
        .map(|peer| *peer.value())
        .ok_or(TransportError::PeerNotRegistered(instance_id))?;

    // Bounded queue: a full queue is how senders observe backpressure.
    let (tx, rx) = flume::bounded(self.channel_capacity);

    let writer = connection_writer_task(
        addr,
        instance_id,
        rx,
        Arc::clone(&self.connections),
        self.cancel_token.clone(),
    );
    rt.spawn(writer);

    debug!("Created new connection to {} ({})", instance_id, addr);

    Ok(ConnectionHandle { tx })
}
}
impl
Transport
for
TcpTransport
{
fn
key
(
&
self
)
->
TransportKey
{
self
.key
.clone
()
}
fn
address
(
&
self
)
->
WorkerAddress
{
self
.local_address
.clone
()
}
fn
register
(
&
self
,
peer_info
:
PeerInfo
)
->
Result
<
(),
TransportError
>
{
// Get endpoint from peer's address
let
endpoint
=
peer_info
.worker_address
()
.get_entry
(
&
self
.key
)
.map_err
(|
_
|
TransportError
::
NoEndpoint
)
?
.ok_or
(
TransportError
::
NoEndpoint
)
?
;
// Parse TCP endpoint (expected format: "tcp://host:port" or "host:port")
let
addr
=
parse_tcp_endpoint
(
&
endpoint
)
.map_err
(|
e
|
{
error!
(
"Failed to parse TCP endpoint: {}"
,
e
);
TransportError
::
InvalidEndpoint
})
?
;
// Store peer address
self
.peers
.insert
(
peer_info
.instance_id
(),
addr
);
debug!
(
"Registered peer {} at {}"
,
peer_info
.instance_id
(),
addr
);
Ok
(())
}
/// Queues a frame for delivery to `instance_id`.
///
/// Never blocks: the frame is pushed onto the existing connection's queue when
/// there is room; otherwise a connection is looked up or created and the send is
/// completed by a spawned task. Every failure is reported through `on_error`
/// together with the original header and payload.
#[inline]
fn send_message(
    &self,
    instance_id: crate::InstanceId,
    header: Vec<u8>,
    payload: Vec<u8>,
    message_type: MessageType,
    on_error: std::sync::Arc<dyn TransportErrorHandler>,
) {
    let task = SendTask {
        msg_type: message_type,
        header: Bytes::from(header),
        payload: Bytes::from(payload),
        on_error,
    };

    // Fast path: non-blocking push onto an existing connection.
    let pending = match self.connections.get(&instance_id) {
        None => task,
        Some(conn) => match conn.tx.try_send(task) {
            Ok(()) => return,
            // Queue full: finish the send asynchronously below.
            Err(flume::TrySendError::Full(task)) => task,
            Err(flume::TrySendError::Disconnected(task)) => {
                // Writer is gone: release the guard, evict the stale entry, and
                // let the slow path build a fresh connection.
                drop(conn);
                self.connections
                    .remove_if(&instance_id, |_, stale| stale.tx.is_disconnected());
                task
            }
        },
    };

    // Slow path: needs the runtime captured by `start()`.
    let Some(rt) = self.runtime.get() else {
        pending.on_error("Transport not started");
        return;
    };

    let conn = match self.get_or_create_connection(instance_id) {
        Ok(conn) => conn,
        Err(e) => {
            pending.on_error(format!("Failed to create connection: {}", e));
            return;
        }
    };

    rt.spawn(async move {
        if let Err(flume::SendError(rejected)) = conn.tx.send_async(pending).await {
            rejected.on_error("Connection closed");
        }
    });
}
fn
start
(
&
self
,
_
instance_id
:
crate
::
InstanceId
,
channels
:
TransportAdapter
,
rt
:
tokio
::
runtime
::
Handle
,
)
->
futures
::
future
::
BoxFuture
<
'_
,
anyhow
::
Result
<
()
>>
{
// Store runtime handle for use in send_message
self
.runtime
.set
(
rt
.clone
())
.ok
();
// Capture shutdown state from the adapter
self
.shutdown_state
.set
(
channels
.shutdown_state
.clone
())
.ok
();
let
bind_addr
=
self
.bind_addr
;
let
shutdown_state
=
channels
.shutdown_state
.clone
();
// Take ownership of the listener (if present) - we can only start once
let
listener
=
self
.listener
.lock
()
.expect
(
"Listener mutex poisoned"
)
.take
();
Box
::
pin
(
async
move
{
// Create error handler that routes to the transport error handler
struct
DefaultErrorHandler
;
impl
TransportErrorHandler
for
DefaultErrorHandler
{
fn
on_error
(
&
self
,
_
header
:
Bytes
,
_
payload
:
Bytes
,
error
:
String
)
{
warn!
(
"Transport error: {}"
,
error
);
}
}
// Start TCP listener
let
tcp_listener
=
TcpListener
::
builder
()
.bind_addr
(
bind_addr
)
.adapter
(
channels
)
.error_handler
(
std
::
sync
::
Arc
::
new
(
DefaultErrorHandler
))
.shutdown_state
(
shutdown_state
)
.listener
(
listener
)
.build
()
?
;
rt
.spawn
(
async
move
{
if
let
Err
(
e
)
=
tcp_listener
.serve
()
.await
{
error!
(
"TCP listener error: {}"
,
e
);
}
});
info!
(
"TCP transport started on {}"
,
bind_addr
);
Ok
(())
})
}
/// Intentionally does nothing: draining is enforced frame by frame in the
/// listener, so the transport has no state of its own to change.
fn begin_drain(&self) {}
/// Tears the transport down: cancels the teardown token (if `start()` captured
/// a shutdown state), cancels the transport's own token, and forgets every
/// connection handle.
fn shutdown(&self) {
    info!("Shutting down TCP transport");

    // Phase 3 teardown: stops the listener and the connection handlers.
    let teardown = self
        .shutdown_state
        .get()
        .map(|state| state.teardown_token());
    if let Some(token) = teardown {
        token.cancel();
    }

    self.cancel_token.cancel();

    // Drop the stored sender handles.
    self.connections.clear();
}
fn
check_health
(
&
self
,
instance_id
:
crate
::
InstanceId
,
timeout
:
Duration
,
)
->
std
::
pin
::
Pin
<
Box
<
dyn
std
::
future
::
Future
<
Output
=
Result
<
(),
HealthCheckError
>>
+
Send
+
'_
>
,
>
{
Box
::
pin
(
async
move
{
// Check if we have an existing connection
let
connection_exists
=
self
.connections
.contains_key
(
&
instance_id
);
if
let
Some
(
handle
)
=
self
.connections
.get
(
&
instance_id
)
{
// Check if the channel is still connected (socket is still live)
// If the writer task has exited (socket closed), the channel will be disconnected
if
!
handle
.tx
.is_disconnected
()
{
return
Ok
(());
// Connection is alive and healthy
}
// Channel is disconnected — drop guard and remove stale entry
drop
(
handle
);
self
.connections
.remove_if
(
&
instance_id
,
|
_
,
h
|
h
.tx
.is_disconnected
());
}
// No existing connection or connection is dead - verify peer is reachable
let
addr
=
*
self
.peers
.get
(
&
instance_id
)
.ok_or
(
HealthCheckError
::
PeerNotRegistered
)
?
.value
();
// Try to connect (and immediately drop) to verify peer is reachable
match
tokio
::
time
::
timeout
(
timeout
,
TcpStream
::
connect
(
addr
))
.await
{
Ok
(
Ok
(
_
stream
))
=>
{
// Connection successful, drop immediately
// If we never had a connection before, report NeverConnected
// If we had one before that failed, report Ok (peer is reachable now)
if
connection_exists
{
Ok
(())
}
else
{
Err
(
HealthCheckError
::
NeverConnected
)
}
}
Ok
(
Err
(
_
))
=>
Err
(
HealthCheckError
::
ConnectionFailed
),
Err
(
_
)
=>
Err
(
HealthCheckError
::
Timeout
),
}
})
}
}
/// Connection writer task
///
/// This task runs on the LocalSet and handles writing framed bytes to the TCP stream.
/// It receives pre-encoded frames via a flume channel and writes them to the socket.
///
/// Cleanup (draining queued messages and removing the stale map entry) always runs,
/// even if the initial TCP connect fails.
async
fn
connection_writer_task
(
addr
:
SocketAddr
,
instance_id
:
crate
::
InstanceId
,
rx
:
flume
::
Receiver
<
SendTask
>
,
connections
:
Arc
<
DashMap
<
crate
::
InstanceId
,
ConnectionHandle
>>
,
_
cancel_token
:
CancellationToken
,
)
->
Result
<
()
>
{
let
result
=
connection_writer_inner
(
addr
,
instance_id
,
&
rx
)
.await
;
// Always drain queued messages and notify their error handlers.
//
// TODO: There is a tiny race between the drain finishing and `drop(rx)`:
// a sender on another thread could `try_send` successfully in that window,
// and the message would be silently dropped when rx is destroyed. Closing
// this fully would require swapping the map entry with a "poisoned" handle
// (a disconnected tx) before draining, so fast-path senders see a failure
// instead. Not worth the complexity today — at most one message is affected,
// and async senders already get `SendError` once rx is dropped.
while
let
Ok
(
msg
)
=
rx
.try_recv
()
{
msg
.on_error
(
"Connection closed"
);
}
// Drop the receiver so our sender half becomes disconnected, then remove
// the stale entry. The predicate ensures we only remove our own entry —
// a replacement connection's tx will still be connected.
drop
(
rx
);
connections
.remove_if
(
&
instance_id
,
|
_
,
h
|
h
.tx
.is_disconnected
());
debug!
(
"Connection to {} ({}) closed"
,
instance_id
,
addr
);
result
}
/// Inner loop: connect, configure the socket, and send frames until the channel
/// closes or a write error occurs.
async
fn
connection_writer_inner
(
addr
:
SocketAddr
,
instance_id
:
crate
::
InstanceId
,
rx
:
&
flume
::
Receiver
<
SendTask
>
,
)
->
Result
<
()
>
{
debug!
(
"Connecting to {}"
,
addr
);
let
mut
stream
=
TcpStream
::
connect
(
addr
)
.await
.context
(
"connect failed"
)
?
;
if
let
Err
(
e
)
=
stream
.set_nodelay
(
true
)
{
warn!
(
"Failed to set TCP_NODELAY: {}"
,
e
);
}
let
sock
=
socket2
::
SockRef
::
from
(
&
stream
);
if
let
Err
(
e
)
=
sock
.set_tcp_keepalive
(
&
socket2
::
TcpKeepalive
::
new
()
.with_time
(
Duration
::
from_secs
(
60
))
.with_interval
(
Duration
::
from_secs
(
10
)),
)
{
warn!
(
"Failed to set keepalive: {}"
,
e
);
}
if
let
Err
(
e
)
=
sock
.set_send_buffer_size
(
1_048_576
)
{
warn!
(
"Failed to set send buffer size: {}"
,
e
);
}
debug!
(
"Connected to {}"
,
addr
);
while
let
Ok
(
msg
)
=
rx
.recv_async
()
.await
{
if
let
Err
(
e
)
=
TcpFrameCodec
::
encode_frame
(
&
mut
stream
,
msg
.msg_type
,
&
msg
.header
,
&
msg
.payload
)
.await
{
error!
(
"Write error to {} ({}): {}"
,
instance_id
,
addr
,
e
);
msg
.on_error
(
format!
(
"Failed to write to stream: {}"
,
e
));
break
;
}
}
Ok
(())
}
/// Turns raw endpoint bytes into a `SocketAddr`.
///
/// Accepted forms:
/// - "tcp://host:port"
/// - "host:port"
///
/// The first address produced by `to_socket_addrs` is used.
///
/// # Errors
/// Fails when the bytes are not UTF-8, when address conversion fails, or when
/// it yields no address.
fn parse_tcp_endpoint(endpoint: &[u8]) -> Result<SocketAddr> {
    let text = std::str::from_utf8(endpoint).context("endpoint is not valid UTF-8")?;

    // The scheme prefix is optional.
    let host_port = match text.strip_prefix("tcp://") {
        Some(rest) => rest,
        None => text,
    };

    let first = host_port
        .to_socket_addrs()
        .context("failed to parse socket address")?
        .next();

    match first {
        Some(addr) => Ok(addr),
        None => Err(anyhow::anyhow!("no addresses resolved")),
    }
}
/// Maps a wildcard bind address to an address that can be advertised to peers.
///
/// A socket bound to the unspecified address (0.0.0.0 or ::) cannot be dialled
/// as-is, so it is replaced by the loopback address of the same family
/// (127.0.0.1 or ::1) with the port kept. Any specific address is returned
/// unchanged.
///
/// Loopback only works for peers on the same machine; a multi-node deployment
/// needs interface discovery or explicit configuration instead.
fn resolve_advertise_address(bind_addr: SocketAddr) -> SocketAddr {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    // Already a concrete address: advertise it as-is.
    if !bind_addr.ip().is_unspecified() {
        return bind_addr;
    }

    // Wildcard: substitute the loopback address of the matching family.
    let loopback = if bind_addr.is_ipv4() {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        IpAddr::V6(Ipv6Addr::LOCALHOST)
    };
    SocketAddr::new(loopback, bind_addr.port())
}
/// Builder for TcpTransport
///
/// Collects the settings that `build()` passes to `TcpTransport::new`.
pub struct TcpTransportBuilder {
    // Address to bind to. `build()` fails if this is still `None`.
    // `from_listener()` fills it in from the listener's local address.
    bind_addr: Option<SocketAddr>,
    // Transport key; `build()` falls back to "tcp" when unset.
    key: Option<TransportKey>,
    // Channel capacity used for backpressure (256 unless overridden).
    channel_capacity: usize,
    // Optional pre-bound listener supplied through `from_listener()`.
    listener: Option<std::net::TcpListener>,
}
impl TcpTransportBuilder {
    /// Create a new builder
    ///
    /// Nothing is configured yet except the channel capacity, which starts at 256.
    pub fn new() -> Self {
        Self {
            bind_addr: None,
            key: None,
            channel_capacity: 256,
            listener: None,
        }
    }

    /// Set the bind address
    ///
    /// Mutually exclusive with `from_listener()`; combining the two is
    /// reported as an error by `from_listener()` or by `build()`.
    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
        self.bind_addr = Some(addr);
        self
    }

    /// Set the transport key
    ///
    /// When no key is set, `build()` uses `"tcp"`.
    pub fn key(mut self, key: TransportKey) -> Self {
        self.key = Some(key);
        self
    }

    /// Set the channel capacity for backpressure (default: 256)
    pub fn channel_capacity(mut self, capacity: usize) -> Self {
        self.channel_capacity = capacity;
        self
    }

    /// Use a pre-bound TcpListener instead of binding to a specific address
    ///
    /// This is useful for tests where you want to bind to port 0 and get an OS-assigned
    /// port without creating a race condition between binding and starting the transport.
    ///
    /// Note: This is mutually exclusive with `bind_addr()`. Using both will result in an error.
    ///
    /// # Errors
    /// Fails if `bind_addr()` was already called, or if the listener's local
    /// address cannot be read.
    pub fn from_listener(mut self, listener: std::net::TcpListener) -> Result<Self> {
        // Validate mutual exclusivity: can't use both bind_addr() and from_listener()
        if self.bind_addr.is_some() {
            anyhow::bail!(
                "Cannot use both bind_addr() and from_listener() - they are mutually exclusive"
            );
        }

        let addr = listener
            .local_addr()
            .context("Failed to get local address from listener")?;

        // The listener's own address becomes the bind address.
        self.bind_addr = Some(addr);
        self.listener = Some(listener);
        Ok(self)
    }

    /// Build the TcpTransport
    ///
    /// # Errors
    /// Fails if no bind address was provided (directly or through a listener),
    /// if `bind_addr()` was called after `from_listener()` with an address that
    /// differs from the listener's, or if the worker address cannot be built.
    pub fn build(self) -> Result<TcpTransport> {
        let bind_addr = self
            .bind_addr
            .ok_or_else(|| anyhow::anyhow!("bind_addr is required"))?;

        // `from_listener()` can only reject a *preceding* `bind_addr()` call.
        // Catch the reverse order here: a listener is present but the bind
        // address was later overridden with something else.
        if let Some(listener) = &self.listener {
            let listener_addr = listener
                .local_addr()
                .context("Failed to get local address from listener")?;
            if listener_addr != bind_addr {
                anyhow::bail!(
                    "Cannot use both bind_addr() and from_listener() - they are mutually exclusive"
                );
            }
        }

        let key = self.key.unwrap_or_else(|| TransportKey::from("tcp"));

        // Resolve advertise address (handle 0.0.0.0 -> 127.0.0.1 for local testing)
        let advertise_addr = resolve_advertise_address(bind_addr);
        let local_endpoint = format!("tcp://{}", advertise_addr);

        // The advertised endpoint is published under this transport's key.
        let mut addr_builder = crate::address::WorkerAddressBuilder::new();
        addr_builder.add_entry(key.clone(), local_endpoint.as_bytes().to_vec())?;
        let local_address = addr_builder.build()?;

        Ok(TcpTransport::new(
            bind_addr,
            key,
            local_address,
            self.channel_capacity,
            self.listener,
        ))
    }
}
impl
Default
for
TcpTransportBuilder
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
/// Unit tests for the TCP transport: endpoint parsing, builder validation,
/// advertise-address resolution, and stale-connection handling.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::address::WorkerAddressBuilder;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use velo_common::PeerInfo;

    /// Error handler that discards errors (for tests that don't need to track them).
    struct NullErrorHandler;

    impl TransportErrorHandler for NullErrorHandler {
        fn on_error(&self, _: Bytes, _: Bytes, _: String) {}
    }

    /// Error handler that counts errors (for tests that verify error routing).
    struct TrackingErrorHandler {
        count: AtomicUsize,
    }

    impl TrackingErrorHandler {
        /// Start with an error count of zero.
        fn new() -> Self {
            Self {
                count: AtomicUsize::new(0),
            }
        }

        /// Number of times `on_error` has been invoked so far.
        fn error_count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl TransportErrorHandler for TrackingErrorHandler {
        fn on_error(&self, _: Bytes, _: Bytes, _: String) {
            self.count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Build a `PeerInfo` whose TCP endpoint points at `addr`.
    ///
    /// The peer gets a freshly generated instance id.
    fn make_tcp_peer(addr: SocketAddr) -> PeerInfo {
        let instance_id = crate::InstanceId::new_v4();
        let mut builder = WorkerAddressBuilder::new();
        builder
            .add_entry("tcp", format!("tcp://{}", addr).into_bytes())
            .unwrap();
        PeerInfo::new(instance_id, builder.build().unwrap())
    }

    /// Build a `TcpTransport` with its runtime set, bound to a real listener.
    /// Returns `(transport, listener_addr)`.
    ///
    /// Must be called from within a Tokio runtime because it uses
    /// `Handle::current()`.
    fn make_transport() -> (TcpTransport, SocketAddr) {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let transport = TcpTransportBuilder::new()
            .from_listener(listener)
            .unwrap()
            .build()
            .unwrap();
        // Set the runtime handle so `get_or_create_connection` can spawn tasks.
        transport
            .runtime
            .set(tokio::runtime::Handle::current())
            .ok();
        (transport, addr)
    }

    /// Insert a stale `ConnectionHandle` into the transport's connections map.
    /// A "stale" handle is one whose receiver has been dropped.
    fn insert_stale_handle(transport: &TcpTransport, instance_id: crate::InstanceId) {
        let (tx, _rx) = flume::bounded::<SendTask>(1);
        // `_rx` is dropped when this function returns; from then on
        // tx.is_disconnected() == true
        transport
            .connections
            .insert(instance_id, ConnectionHandle { tx });
    }

    #[test]
    fn test_parse_tcp_endpoint() {
        // With tcp:// prefix
        let addr = parse_tcp_endpoint(b"tcp://127.0.0.1:5555").unwrap();
        assert_eq!(addr.port(), 5555);

        // Without prefix
        let addr = parse_tcp_endpoint(b"127.0.0.1:6666").unwrap();
        assert_eq!(addr.port(), 6666);

        // Invalid
        assert!(parse_tcp_endpoint(b"invalid").is_err());
    }

    #[test]
    fn test_builder_requires_bind_addr() {
        let result = TcpTransportBuilder::new().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_with_bind_addr() {
        let addr = "127.0.0.1:0".parse().unwrap();
        let result = TcpTransportBuilder::new().bind_addr(addr).build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_with_listener() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let result = TcpTransportBuilder::new().from_listener(listener);
        assert!(result.is_ok());

        let result = result.unwrap().build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_bind_addr_and_listener_mutually_exclusive() {
        let addr = "127.0.0.1:0".parse().unwrap();
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();

        // bind_addr() first, then from_listener(): the latter must refuse.
        let result = TcpTransportBuilder::new()
            .bind_addr(addr)
            .from_listener(listener);
        assert!(result.is_err());

        let err_msg = format!("{}", result.err().unwrap());
        assert!(err_msg.contains("mutually exclusive"));
    }

    #[test]
    fn test_resolve_advertise_address_ipv4_unspecified() {
        use std::net::{IpAddr, Ipv4Addr};

        // 0.0.0.0 should resolve to 127.0.0.1
        let bind_addr: SocketAddr = "0.0.0.0:12345".parse().unwrap();
        let resolved = resolve_advertise_address(bind_addr);
        assert_eq!(resolved.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
        assert_eq!(resolved.port(), 12345);

        // Already specific address should remain unchanged
        let specific: SocketAddr = "192.168.1.100:8080".parse().unwrap();
        let resolved = resolve_advertise_address(specific);
        assert_eq!(resolved, specific);
    }

    #[test]
    fn test_resolve_advertise_address_ipv6_unspecified() {
        use std::net::{IpAddr, Ipv6Addr};

        // :: should resolve to ::1
        let bind_addr: SocketAddr = "[::]:12345".parse().unwrap();
        let resolved = resolve_advertise_address(bind_addr);
        assert_eq!(resolved.ip(), IpAddr::V6(Ipv6Addr::LOCALHOST));
        assert_eq!(resolved.port(), 12345);

        // Already specific IPv6 address should remain unchanged
        let specific: SocketAddr = "[::1]:8080".parse().unwrap();
        let resolved = resolve_advertise_address(specific);
        assert_eq!(resolved, specific);
    }

    #[tokio::test]
    async fn test_get_or_create_connection_replaces_stale_handle() {
        let (transport, _our_addr) = make_transport();

        // Start a listener that the transport can connect to
        let peer_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();

        let peer = make_tcp_peer(peer_addr);
        let iid = peer.instance_id();
        transport.register(peer).unwrap();

        // Insert a stale handle
        insert_stale_handle(&transport, iid);
        assert!(
            transport
                .connections
                .get(&iid)
                .unwrap()
                .tx
                .is_disconnected()
        );

        // get_or_create_connection should replace the stale handle with a live one
        let handle = transport.get_or_create_connection(iid).unwrap();
        assert!(!handle.tx.is_disconnected());

        // The map entry should also be live
        let entry = transport.connections.get(&iid).unwrap();
        assert!(!entry.tx.is_disconnected());
    }

    #[tokio::test]
    async fn test_check_health_removes_stale_entry() {
        let (transport, _our_addr) = make_transport();

        // Start a listener so the peer is "reachable"
        let peer_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();

        let peer = make_tcp_peer(peer_addr);
        let iid = peer.instance_id();
        transport.register(peer).unwrap();

        // Insert stale handle — simulates a dead writer task
        insert_stale_handle(&transport, iid);
        assert!(transport.connections.contains_key(&iid));

        // check_health should remove the stale entry and verify the peer is reachable
        let result = transport.check_health(iid, Duration::from_secs(2)).await;

        // Stale entry should be gone
        assert!(!transport.connections.contains_key(&iid));

        // Since there WAS a previous connection entry, check_health returns Ok
        // (the peer is reachable via our test listener)
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_writer_task_cleans_up_on_write_error() {
        // Bind a listener, accept once, then drop everything to cause a write error
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let iid = crate::InstanceId::new_v4();
        let (tx, rx) = flume::bounded::<SendTask>(8);
        let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
            Arc::new(DashMap::new());
        connections.insert(iid, ConnectionHandle { tx: tx.clone() });

        let conns = Arc::clone(&connections);
        let cancel = CancellationToken::new();

        // Spawn the writer task
        let writer = tokio::spawn(connection_writer_task(addr, iid, rx, conns, cancel));

        // Accept the connection, then immediately drop it + the listener
        let (stream, _) = listener.accept().await.unwrap();
        drop(stream);
        drop(listener);

        // Send a message — the writer should hit a broken-pipe error
        tx.send(SendTask {
            msg_type: MessageType::Message,
            header: Bytes::from_static(b"hdr"),
            payload: Bytes::from_static(b"pay"),
            on_error: Arc::new(NullErrorHandler),
        })
        .unwrap();

        // Wait for writer task to finish
        let _ = writer.await;

        // The writer should have removed the stale entry from the map
        assert!(
            !connections.contains_key(&iid),
            "writer task should clean up its DashMap entry on write error"
        );
    }

    #[tokio::test]
    async fn test_send_message_does_not_fail_on_stale_handle() {
        let (transport, _our_addr) = make_transport();

        // Start a listener that accepts connections (simulates a healthy peer)
        let peer_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let peer_addr = peer_listener.local_addr().unwrap();

        let peer = make_tcp_peer(peer_addr);
        let iid = peer.instance_id();
        transport.register(peer).unwrap();

        // Insert a stale handle
        insert_stale_handle(&transport, iid);

        // send_message should detect the stale handle and create a new one,
        // NOT immediately call on_error
        let error_handler = Arc::new(TrackingErrorHandler::new());
        transport.send_message(
            iid,
            b"test-header".to_vec(),
            b"test-payload".to_vec(),
            MessageType::Message,
            error_handler.clone(),
        );

        // Accept the connection that the new writer task will establish
        let (mut stream, _) = peer_listener.accept().await.unwrap();

        // Read the framed message from the stream to confirm delivery
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 256];
        // Give the async writer a moment to flush the frame
        let n = tokio::time::timeout(Duration::from_secs(2), stream.read(&mut buf))
            .await
            .expect("timed out waiting for data")
            .expect("read error");
        assert!(n > 0, "expected data from the writer task");

        // No errors should have been reported
        assert_eq!(
            error_handler.error_count(),
            0,
            "send_message should retry on stale handle, not fail"
        );

        // The connections map should now contain a live handle
        let entry = transport.connections.get(&iid).unwrap();
        assert!(
            !entry.tx.is_disconnected(),
            "stale handle should have been replaced with a live one"
        );
    }

    #[tokio::test]
    async fn test_writer_task_drains_on_connect_failure() {
        // Use an address where nothing is listening so connect will fail.
        // Binding then immediately dropping frees the port again, so it is
        // expected (though not strictly guaranteed) to be closed.
        let tmp = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = tmp.local_addr().unwrap();
        drop(tmp);

        let iid = crate::InstanceId::new_v4();
        let (tx, rx) = flume::bounded::<SendTask>(8);
        let connections: Arc<DashMap<crate::InstanceId, ConnectionHandle>> =
            Arc::new(DashMap::new());
        connections.insert(iid, ConnectionHandle { tx: tx.clone() });

        // Queue a message *before* the writer task even starts — this simulates
        // the race between create_connection returning and connect completing.
        let error_handler = Arc::new(TrackingErrorHandler::new());
        tx.send(SendTask {
            msg_type: MessageType::Message,
            header: Bytes::from_static(b"hdr"),
            payload: Bytes::from_static(b"pay"),
            on_error: error_handler.clone(),
        })
        .unwrap();

        let conns = Arc::clone(&connections);
        let cancel = CancellationToken::new();
        let writer = tokio::spawn(connection_writer_task(addr, iid, rx, conns, cancel));
        let _ = writer.await;

        assert_eq!(
            error_handler.error_count(),
            1,
            "queued message should have its on_error called when connect fails"
        );
        assert!(
            !connections.contains_key(&iid),
            "writer task should clean up its DashMap entry on connect failure"
        );
    }
}
lib/velo-transports/src/transport.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
bytes
::
Bytes
;
use
futures
::
future
::
BoxFuture
;
use
crate
::{
InstanceId
,
PeerInfo
,
TransportKey
,
WorkerAddress
};
use
std
::
sync
::
atomic
::{
AtomicBool
,
AtomicUsize
,
Ordering
};
use
std
::{
sync
::
Arc
,
time
::
Duration
};
use
tokio
::
sync
::
Notify
;
use
tokio_util
::
sync
::
CancellationToken
;
/// Errors returned by individual [`Transport`] implementations.
///
/// Display messages are generated by `thiserror` from the `#[error(...)]`
/// attributes below.
#[derive(thiserror::Error, Debug)]
pub enum TransportError {
    /// The peer's [`WorkerAddress`] does not contain an entry for this transport.
    #[error("No endpoint found for transport")]
    NoEndpoint,

    /// The endpoint string could not be parsed (malformed URL, invalid address).
    #[error("Invalid endpoint format")]
    InvalidEndpoint,

    /// The target peer was never registered with this transport.
    ///
    /// Carries the instance id of the peer, which is included in the message.
    #[error("Peer not registered: {0}")]
    PeerNotRegistered(InstanceId),

    /// The transport has not been started yet (no runtime handle).
    #[error("Transport not started")]
    NotStarted,

    /// No responders available for the peer (e.g. NATS request with no subscriber).
    #[error("No responders for peer")]
    NoResponders,
}
/// Error type specific to health check operations
///
/// Returned by [`Transport::check_health`]. Unlike [`TransportError`], this
/// type is `Clone` and comparable, and its variants carry no payload.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckError {
    /// The peer was never registered with this transport.
    #[error("Peer not registered with transport")]
    PeerNotRegistered,

    /// The transport has not been started yet.
    #[error("Transport not started")]
    TransportNotStarted,

    /// The peer is registered but no connection has ever been established.
    #[error("Connection never established to peer")]
    NeverConnected,

    /// An existing connection is unhealthy or the peer is unreachable.
    #[error("Connection failed or peer unreachable")]
    ConnectionFailed,

    /// The health check exceeded the specified timeout.
    #[error("Health check timed out")]
    Timeout,
}
/// Shared shutdown coordinator for graceful multi-phase shutdown.
///
/// **Phases**:
/// 1. **Gate** — `begin_drain()` flips the draining flag; transports reject new inbound requests.
/// 2. **Drain** — `wait_for_drain()` blocks until all in-flight guards are dropped.
/// 3. **Teardown** — `teardown_token().cancel()` kills listeners and writer tasks.
///
/// Hot-path cost: a single `AtomicBool::load(Relaxed)` per frame to check `is_draining()`.
///
/// Cloning is cheap: every clone points at the same inner state through an `Arc`.
#[derive(Clone)]
pub struct ShutdownState {
    // Shared state; all clones and all `InFlightGuard`s reference this allocation.
    inner: Arc<ShutdownStateInner>,
}
/// State shared between every [`ShutdownState`] clone and every [`InFlightGuard`].
struct ShutdownStateInner {
    // Phase 1 flag: set to `true` by `begin_drain()`, read by `is_draining()`.
    draining: AtomicBool,
    // Number of live `InFlightGuard`s (incremented in `acquire()`, decremented on guard drop).
    in_flight: AtomicUsize,
    // Signalled by the guard whose drop takes `in_flight` to zero; awaited in `wait_for_drain()`.
    drain_complete: Notify,
    // Phase 3 token handed out by `teardown_token()`.
    teardown_token: CancellationToken,
}
impl ShutdownState {
    /// Create a new shutdown state. Not draining, zero in-flight.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(ShutdownStateInner {
                draining: AtomicBool::new(false),
                in_flight: AtomicUsize::new(0),
                drain_complete: Notify::new(),
                teardown_token: CancellationToken::new(),
            }),
        }
    }

    /// Returns `true` if drain has been initiated (Phase 1).
    ///
    /// Uses `Relaxed` ordering — safe for the hot-path gate check because
    /// the flag is monotonic (false → true, never reset).
    #[inline]
    pub fn is_draining(&self) -> bool {
        self.inner.draining.load(Ordering::Relaxed)
    }

    /// Begin Phase 1: flip the draining flag. Idempotent.
    pub fn begin_drain(&self) {
        self.inner.draining.store(true, Ordering::Release);
    }

    /// Acquire an in-flight guard. The guard increments the counter on creation
    /// and decrements it on drop. Use this to track requests that are being processed.
    ///
    /// Guards are still acquirable after `begin_drain()` — this is intentional
    /// so that already-accepted work can be tracked.
    pub fn acquire(&self) -> InFlightGuard {
        self.inner.in_flight.fetch_add(1, Ordering::AcqRel);
        InFlightGuard {
            inner: self.inner.clone(),
        }
    }

    /// Current number of in-flight requests. Primarily for testing/debugging.
    pub fn in_flight_count(&self) -> usize {
        self.inner.in_flight.load(Ordering::Acquire)
    }

    /// Wait until in-flight count reaches zero. Returns immediately if already zero.
    ///
    /// The counter is re-checked after every wakeup, so the future only
    /// resolves once the count is observed at zero.
    pub async fn wait_for_drain(&self) {
        loop {
            // Create the `Notified` future BEFORE reading the counter.
            // `notify_waiters()` stores no permit, so a guard dropped between
            // the counter check and a later `notified()` call would be missed
            // and this loop would wait forever. A `Notified` future receives
            // `notify_waiters()` wakeups from the moment it is created.
            let notified = self.inner.drain_complete.notified();

            if self.inner.in_flight.load(Ordering::Acquire) == 0 {
                return;
            }

            notified.await;
        }
    }

    /// Get the Phase 3 teardown token. Cancel this to kill listeners/writers.
    pub fn teardown_token(&self) -> &CancellationToken {
        &self.inner.teardown_token
    }
}
impl
Default
for
ShutdownState
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
/// RAII guard that decrements the in-flight counter on drop.
///
/// Created by [`ShutdownState::acquire`], which performs the matching increment.
pub struct InFlightGuard {
    // Same shared state as the `ShutdownState` that produced this guard.
    inner: Arc<ShutdownStateInner>,
}
impl InFlightGuard {
    /// Finish the tracked request now; this is the same as dropping the guard.
    pub fn complete(self) {
        // Consuming `self` runs the Drop impl, which does the decrement.
        drop(self);
    }
}
/// Dropping a guard releases one in-flight slot and wakes drain waiters when
/// the count reaches zero.
impl Drop for InFlightGuard {
    fn drop(&mut self) {
        // `fetch_sub` returns the value before the decrement, so a previous
        // value of 1 means this guard was the last one outstanding.
        let was_last = self.inner.in_flight.fetch_sub(1, Ordering::AcqRel) == 1;
        if was_last {
            self.inner.drain_complete.notify_waiters();
        }
    }
}
/// Policy for how long to wait during the drain phase.
///
/// The drain phase is Phase 2 of the shutdown sequence described on
/// [`ShutdownState`].
#[derive(Debug, Clone)]
pub enum ShutdownPolicy {
    /// Wait indefinitely for all in-flight requests to complete.
    WaitForever,
    /// Wait up to the given duration, then force teardown.
    Timeout(Duration),
}
/// Abstraction over a single message transport (TCP, HTTP, NATS, gRPC, UCX).
///
/// Implementations handle peer registration, message sending, listener lifecycle,
/// health checking, and graceful shutdown. The trait is object-safe so transports
/// can be stored as `Arc<dyn Transport>`.
pub trait Transport: Send + Sync {
    /// Unique key identifying this transport (e.g. `"tcp"`, `"grpc"`).
    fn key(&self) -> TransportKey;

    /// The [`WorkerAddress`] fragment advertised by this transport.
    fn address(&self) -> WorkerAddress;

    /// Register a remote peer, extracting its endpoint from [`PeerInfo`].
    ///
    /// # Errors
    /// Returns a [`TransportError`] when the peer cannot be registered.
    fn register(&self, peer_info: PeerInfo) -> Result<(), TransportError>;

    /// Sends an active message to the remote instance
    ///
    /// This method returns nothing; delivery failures are reported through
    /// `on_error` (see [`TransportErrorHandler`]).
    fn send_message(
        &self,
        instance_id: InstanceId,
        header: Vec<u8>,
        payload: Vec<u8>,
        message_type: MessageType,
        on_error: Arc<dyn TransportErrorHandler>,
    );

    /// Start the transport (bind listener, spawn tasks) for the given instance.
    ///
    /// `channels` is the [`TransportAdapter`] used to forward inbound frames and
    /// `rt` is the Tokio runtime handle the transport may use.
    fn start(
        &self,
        instance_id: InstanceId,
        channels: TransportAdapter,
        rt: tokio::runtime::Handle,
    ) -> BoxFuture<'_, anyhow::Result<()>>;

    /// Tear down the transport, cancelling all tasks and closing connections.
    fn shutdown(&self);

    /// Begin draining: reject new inbound requests while allowing responses.
    ///
    /// Default implementation is a no-op. Transports that need per-frame
    /// gating (e.g., unsubscribing from NATS subjects) should override this.
    fn begin_drain(&self) {}

    /// Check if a registered peer is reachable and healthy
    ///
    /// Returns Ok(()) if peer responds to health check within timeout.
    /// Different transports implement this differently:
    /// - NATS: request/reply to health subject
    /// - TCP: check existing connection or attempt new connection
    /// - HTTP: HEAD request to health endpoint
    /// - UCX: endpoint status check
    ///
    /// # Errors
    /// - `PeerNotRegistered`: Peer was never registered with this transport
    /// - `TransportNotStarted`: Transport hasn't been started yet
    /// - `NeverConnected`: Peer is registered but no connection has been established
    /// - `ConnectionFailed`: Connection exists/existed but is currently unhealthy or unreachable
    /// - `Timeout`: Health check took longer than the specified timeout
    fn check_health(
        &self,
        instance_id: InstanceId,
        timeout: Duration,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<(), HealthCheckError>> + Send + '_>,
    >;
}
/// Callback trait invoked when a transport fails to deliver a message.
///
/// The original `header` and `payload` are returned so higher layers can
/// retry or log the failure.
///
/// Handlers must be `Send + Sync`; they are passed to
/// [`Transport::send_message`] as `Arc<dyn TransportErrorHandler>`.
pub trait TransportErrorHandler: Send + Sync {
    /// Called when message delivery fails. Receives the original data and error description.
    fn on_error(&self, header: Bytes, payload: Bytes, error: String);
}
/// Frame-kind tag used to route inbound frames to the appropriate stream.
///
/// The enum is `#[repr(u8)]` with explicit discriminants `0..=4`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    #[allow(missing_docs)]
    Message = 0,
    #[allow(missing_docs)]
    Response = 1,
    #[allow(missing_docs)]
    Ack = 2,
    #[allow(missing_docs)]
    Event = 3,
    /// Sent back to a peer when we are draining and cannot accept new messages.
    /// The original request header is echoed back for correlation.
    ShuttingDown = 4,
}

impl MessageType {
    /// Decode a raw byte into a `MessageType`; unknown values give `None`.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Indexed by discriminant: position `i` holds the variant whose value is `i`.
        const BY_DISCRIMINANT: [MessageType; 5] = [
            MessageType::Message,
            MessageType::Response,
            MessageType::Ack,
            MessageType::Event,
            MessageType::ShuttingDown,
        ];
        BY_DISCRIMINANT.get(usize::from(value)).copied()
    }

    /// Encode this `MessageType` as its `u8` discriminant.
    pub fn as_u8(self) -> u8 {
        self as u8
    }
}
/// Sender-side handle given to transports for routing inbound frames.
///
/// Each transport receives a clone of this adapter during [`Transport::start`]
/// and uses it to forward decoded `(header, payload)` pairs to the appropriate
/// stream based on [`MessageType`].
///
/// The matching receivers live in [`DataStreams`]; [`make_channels`] creates both.
#[derive(Clone)]
pub struct TransportAdapter {
    /// Channel for inbound [`MessageType::Message`] frames.
    pub message_stream: flume::Sender<(Bytes, Bytes)>,
    /// Channel for inbound [`MessageType::Response`] and [`MessageType::ShuttingDown`] frames.
    pub response_stream: flume::Sender<(Bytes, Bytes)>,
    /// Channel for inbound [`MessageType::Ack`] and [`MessageType::Event`] frames.
    pub event_stream: flume::Sender<(Bytes, Bytes)>,
    /// Shared shutdown coordinator for drain-aware routing.
    pub shutdown_state: ShutdownState,
}
/// Receiver-side handle for consuming inbound frames from all transports.
///
/// Returned by [`make_channels`] alongside the corresponding [`TransportAdapter`].
/// Higher layers pull `(header, payload)` pairs from these channels.
pub struct DataStreams {
    /// Receiver for inbound message frames.
    pub message_stream: flume::Receiver<(Bytes, Bytes)>,
    /// Receiver for inbound response and shutting-down frames.
    pub response_stream: flume::Receiver<(Bytes, Bytes)>,
    /// Receiver for inbound ack and event frames.
    pub event_stream: flume::Receiver<(Bytes, Bytes)>,
    /// Shared shutdown coordinator.
    ///
    /// [`make_channels`] gives the paired [`TransportAdapter`] a clone of this
    /// same state.
    pub shutdown_state: ShutdownState,
}
/// The three raw receivers of a [`DataStreams`], in `(message, response, event)`
/// order, as returned by [`DataStreams::into_parts`].
type DataStreamTuple = (
    flume::Receiver<(Bytes, Bytes)>,
    flume::Receiver<(Bytes, Bytes)>,
    flume::Receiver<(Bytes, Bytes)>,
);
impl DataStreams {
    /// Break the handle apart into its raw receivers, ordered
    /// `(message, response, event)`. The shutdown state is discarded.
    pub fn into_parts(self) -> DataStreamTuple {
        let Self {
            message_stream,
            response_stream,
            event_stream,
            ..
        } = self;
        (message_stream, response_stream, event_stream)
    }

    /// Await the next inbound message and pair it with an in-flight guard.
    ///
    /// Yields `(header, payload, guard)`. The in-flight counter stays
    /// incremented until the guard is dropped or `complete()` is called on it.
    ///
    /// # Errors
    /// Propagates the channel's `RecvError`; no guard is acquired in that case.
    pub async fn recv_message_tracked(
        &self,
    ) -> Result<(Bytes, Bytes, InFlightGuard), flume::RecvError> {
        match self.message_stream.recv_async().await {
            // The guard is taken only after a message has actually arrived.
            Ok((header, payload)) => Ok((header, payload, self.shutdown_state.acquire())),
            Err(closed) => Err(closed),
        }
    }
}
/// Build a connected [`TransportAdapter`] (senders) / [`DataStreams`] (receivers) pair.
///
/// Three unbounded channels are created, one each for messages, responses and
/// events. Both halves hold the same [`ShutdownState`], so drain coordination
/// is automatic.
pub fn make_channels() -> (TransportAdapter, DataStreams) {
    let shutdown_state = ShutdownState::new();

    let (message_tx, message_rx) = flume::unbounded();
    let (response_tx, response_rx) = flume::unbounded();
    let (event_tx, event_rx) = flume::unbounded();

    // Sender half: gets a clone of the shared shutdown state.
    let adapter = TransportAdapter {
        message_stream: message_tx,
        response_stream: response_tx,
        event_stream: event_tx,
        shutdown_state: shutdown_state.clone(),
    };

    // Receiver half: takes ownership of the original handle.
    let streams = DataStreams {
        message_stream: message_rx,
        response_stream: response_rx,
        event_stream: event_rx,
        shutdown_state,
    };

    (adapter, streams)
}
/// Unit tests for `ShutdownState`, `InFlightGuard`, `MessageType` and
/// `make_channels`.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, timeout};

    #[test]
    fn test_shutdown_state_initial() {
        let state = ShutdownState::new();
        assert!(!state.is_draining());
        assert_eq!(state.in_flight_count(), 0);
    }

    #[test]
    fn test_begin_drain_flips_flag() {
        let state = ShutdownState::new();
        state.begin_drain();
        assert!(state.is_draining());
    }

    #[test]
    fn test_begin_drain_idempotent() {
        let state = ShutdownState::new();
        state.begin_drain();
        state.begin_drain();
        assert!(state.is_draining());
    }

    #[test]
    fn test_acquire_increments_inflight() {
        let state = ShutdownState::new();
        // Guards are bound to names so they stay alive across the assertions.
        let _g1 = state.acquire();
        assert_eq!(state.in_flight_count(), 1);
        let _g2 = state.acquire();
        assert_eq!(state.in_flight_count(), 2);
    }

    #[test]
    fn test_guard_drop_decrements_inflight() {
        let state = ShutdownState::new();
        let g = state.acquire();
        assert_eq!(state.in_flight_count(), 1);
        drop(g);
        assert_eq!(state.in_flight_count(), 0);
    }

    #[test]
    fn test_guard_complete_decrements() {
        let state = ShutdownState::new();
        let g = state.acquire();
        assert_eq!(state.in_flight_count(), 1);
        g.complete();
        assert_eq!(state.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn test_wait_for_drain_immediate() {
        let state = ShutdownState::new();
        // Should complete immediately since in_flight is 0
        timeout(Duration::from_millis(100), state.wait_for_drain())
            .await
            .expect("wait_for_drain should complete immediately when in_flight is 0");
    }

    #[tokio::test]
    async fn test_wait_for_drain_blocks_then_completes() {
        let state = ShutdownState::new();
        let guard = state.acquire();

        let state_clone = state.clone();
        let handle = tokio::spawn(async move {
            state_clone.wait_for_drain().await;
        });

        // Give the waiter time to park
        sleep(Duration::from_millis(50)).await;
        assert!(!handle.is_finished());

        // Drop guard → should unblock
        drop(guard);
        timeout(Duration::from_millis(100), handle)
            .await
            .expect("should complete after guard drop")
            .unwrap();
    }

    #[tokio::test]
    async fn test_multiple_guards_concurrent() {
        let state = ShutdownState::new();
        let guards: Vec<_> = (0..10).map(|_| state.acquire()).collect();
        assert_eq!(state.in_flight_count(), 10);

        let state_clone = state.clone();
        let handle = tokio::spawn(async move {
            state_clone.wait_for_drain().await;
        });

        // Drop all guards
        drop(guards);
        timeout(Duration::from_millis(100), handle)
            .await
            .expect("should complete after all guards drop")
            .unwrap();
        assert_eq!(state.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn test_drain_with_zero_inflight() {
        let state = ShutdownState::new();
        state.begin_drain();
        // Should complete immediately
        timeout(Duration::from_millis(100), state.wait_for_drain())
            .await
            .expect("should complete immediately with zero in-flight");
    }

    #[test]
    fn test_acquire_works_after_drain() {
        let state = ShutdownState::new();
        state.begin_drain();
        let _g = state.acquire();
        assert_eq!(state.in_flight_count(), 1);
    }

    #[test]
    fn test_guard_drop_during_panic() {
        let state = ShutdownState::new();
        // The closure panics while holding a guard; catch_unwind lets the test
        // inspect the counter afterwards.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _g = state.acquire();
            assert_eq!(state.in_flight_count(), 1);
            panic!("intentional panic");
        }));
        assert!(result.is_err());
        // Guard's Drop should have fired even during unwind
        assert_eq!(state.in_flight_count(), 0);
    }

    #[test]
    fn test_shutting_down_from_u8() {
        assert_eq!(MessageType::from_u8(4), Some(MessageType::ShuttingDown));
    }

    #[test]
    fn test_shutting_down_as_u8() {
        assert_eq!(MessageType::ShuttingDown.as_u8(), 4);
    }

    #[test]
    fn test_unknown_message_type_still_none() {
        assert_eq!(MessageType::from_u8(5), None);
        assert_eq!(MessageType::from_u8(255), None);
    }

    #[test]
    fn test_make_channels_includes_shutdown_state() {
        let (adapter, streams) = make_channels();
        // Both sides should share the same ShutdownState (via Arc)
        assert!(!adapter.shutdown_state.is_draining());
        assert!(!streams.shutdown_state.is_draining());

        // Mutating one should be visible through the other
        adapter.shutdown_state.begin_drain();
        assert!(streams.shutdown_state.is_draining());
    }

    #[tokio::test]
    async fn test_recv_message_tracked_returns_guard() {
        let (adapter, streams) = make_channels();

        // Send a message through the adapter
        adapter
            .message_stream
            .send_async((
                bytes::Bytes::from_static(b"hdr"),
                bytes::Bytes::from_static(b"pay"),
            ))
            .await
            .unwrap();

        // Receive with tracking
        let (header, payload, guard) = streams.recv_message_tracked().await.unwrap();
        assert_eq!(&header[..], b"hdr");
        assert_eq!(&payload[..], b"pay");
        assert_eq!(streams.shutdown_state.in_flight_count(), 1);

        // Drop guard
        drop(guard);
        assert_eq!(streams.shutdown_state.in_flight_count(), 0);
    }

    #[test]
    fn test_shutdown_state_clone_shares_inner() {
        let s1 = ShutdownState::new();
        let s2 = s1.clone();

        s1.begin_drain();
        assert!(s2.is_draining());

        let _g = s1.acquire();
        assert_eq!(s2.in_flight_count(), 1);
    }

    #[test]
    fn test_teardown_token() {
        let state = ShutdownState::new();
        assert!(!state.teardown_token().is_cancelled());
        state.teardown_token().cancel();
        assert!(state.teardown_token().is_cancelled());
    }
}
lib/velo-transports/src/utils/mod.rs
0 → 100644
View file @
63d7c01c
lib/velo-transports/tests/common/mod.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Common test utilities for transport integration tests
//!
//! This module provides a transport-agnostic test infrastructure that can be reused
//! across different transport implementations (TCP, RDMA, UDP, UDS, etc.).
#![allow(dead_code)]
// #[cfg(feature = "grpc")]
// use velo_transports::grpc::{GrpcTransport, GrpcTransportBuilder};
// #[cfg(feature = "http")]
// use velo_transports::http::{HttpTransport, HttpTransportBuilder};
// #[cfg(feature = "nats")]
// use velo_transports::nats::{NatsTransport, NatsTransportBuilder};
// #[cfg(feature = "ucx")]
// use velo_transports::ucx::{UcxTransport, UcxTransportBuilder};
use
bytes
::
Bytes
;
use
std
::
sync
::{
Arc
,
Mutex
};
use
std
::
time
::
Duration
;
use
tokio
::
time
::
timeout
;
use
velo_transports
::{
DataStreams
,
InstanceId
,
MessageType
,
PeerInfo
,
Transport
,
TransportErrorHandler
,
tcp
::{
TcpTransport
,
TcpTransportBuilder
},
};
use
std
::
sync
::
Once
;
use
tracing_subscriber
::
FmtSubscriber
;
// One-time guard used by `init_tracing` so the subscriber setup closure runs at most once.
#[allow(dead_code)]
static INIT: Once = Once::new();
#[allow(dead_code)]
pub
fn
init_tracing
()
{
INIT
.call_once
(||
{
let
_
=
FmtSubscriber
::
builder
()
.with_env_filter
(
"trace"
)
// or "info"
.try_init
();
});
}
pub
mod
scenarios
;
/// Test error handler that tracks errors for verification
///
/// Every call to `on_error` appends a `(header, payload, error)` entry. The list lives
/// behind an `Arc<Mutex<..>>`, so clones of the handler share one list.
#[derive(Clone)]
pub struct TestErrorHandler {
    // Recorded `(header, payload, error message)` entries, in the order they were pushed.
    errors: Arc<Mutex<Vec<(Bytes, Bytes, String)>>>,
}
impl TestErrorHandler {
    /// Create a handler whose error list starts empty.
    pub fn new() -> Self {
        let errors = Arc::new(Mutex::new(Vec::new()));
        Self { errors }
    }

    /// Lock the shared error list.
    ///
    /// Panics if the mutex is poisoned.
    fn entries(&self) -> std::sync::MutexGuard<'_, Vec<(Bytes, Bytes, String)>> {
        self.errors.lock().unwrap()
    }

    /// Return a copy of every `(header, payload, error)` entry recorded so far.
    pub fn get_errors(&self) -> Vec<(Bytes, Bytes, String)> {
        let recorded = self.entries();
        recorded.clone()
    }

    /// Number of errors recorded so far.
    pub fn error_count(&self) -> usize {
        let recorded = self.entries();
        recorded.len()
    }

    /// Forget all recorded errors.
    pub fn clear(&self) {
        let mut recorded = self.entries();
        recorded.clear();
    }
}
impl
TransportErrorHandler
for
TestErrorHandler
{
fn
on_error
(
&
self
,
header
:
Bytes
,
payload
:
Bytes
,
error
:
String
)
{
self
.errors
.lock
()
.unwrap
()
.push
((
header
,
payload
,
error
));
}
}
/// Handle to a transport instance with its streams for testing
///
/// This is a generic test handle that works with any transport implementation.
/// Use `TestTransportHandle::with_factory()` to create instances with custom transports,
/// or use convenience methods like `TestTransportHandle::new()` for TCP transport.
pub struct TestTransportHandle<T: Transport> {
    /// The transport under test.
    pub transport: T,
    /// Receiving side of the channels created by `velo_transports::make_channels()`.
    pub streams: DataStreams,
    /// Identity generated for this transport in `with_factory`.
    pub instance_id: InstanceId,
    /// Error handler handed to the transport on every `send`; records the errors it is given.
    pub error_handler: Arc<TestErrorHandler>,
    // Runtime handle captured in `with_factory`; a clone of it is passed to `transport.start`.
    runtime: tokio::runtime::Handle,
}
impl
<
T
:
Transport
>
TestTransportHandle
<
T
>
{
/// Create a new test transport using a factory function
///
/// This is the generic constructor that works with any transport implementation.
/// The factory function should create and return a transport instance.
///
/// # Example
/// ```ignore
/// let handle = TestTransportHandle::with_factory(|| {
/// MyTransportBuilder::new().build()
/// }).await?;
/// ```
pub
async
fn
with_factory
<
F
>
(
factory
:
F
)
->
anyhow
::
Result
<
Self
>
where
F
:
FnOnce
()
->
anyhow
::
Result
<
T
>
,
{
let
transport
=
factory
()
?
;
let
instance_id
=
InstanceId
::
new_v4
();
let
error_handler
=
Arc
::
new
(
TestErrorHandler
::
new
());
// Create channels for this transport
let
(
adapter
,
streams
)
=
velo_transports
::
make_channels
();
// Get runtime handle
let
runtime
=
tokio
::
runtime
::
Handle
::
current
();
// Start the transport
transport
.start
(
instance_id
,
adapter
,
runtime
.clone
())
.await
?
;
// Give the listener a moment to bind and start accepting connections
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
50
))
.await
;
Ok
(
Self
{
transport
,
streams
,
instance_id
,
error_handler
,
runtime
,
})
}
/// Register another transport as a peer
pub
fn
register_peer
<
U
:
Transport
>
(
&
self
,
other
:
&
TestTransportHandle
<
U
>
,
)
->
anyhow
::
Result
<
()
>
{
let
peer_info
=
PeerInfo
::
new
(
other
.instance_id
,
other
.transport
.address
());
self
.transport
.register
(
peer_info
)
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to register peer: {:?}"
,
e
))
?
;
Ok
(())
}
/// Send a message to a peer
pub
fn
send
(
&
self
,
target
:
InstanceId
,
header
:
Vec
<
u8
>
,
payload
:
Vec
<
u8
>
,
msg_type
:
MessageType
,
)
{
self
.transport
.send_message
(
target
,
header
,
payload
,
msg_type
,
self
.error_handler
.clone
(),
);
}
/// Receive a message with timeout
pub
async
fn
recv_message
(
&
self
,
timeout_duration
:
Duration
)
->
anyhow
::
Result
<
(
Bytes
,
Bytes
)
>
{
timeout
(
timeout_duration
,
self
.streams.message_stream
.recv_async
())
.await
.map_err
(|
_
|
anyhow
::
anyhow!
(
"Timeout waiting for message"
))
?
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Channel error: {}"
,
e
))
}
/// Receive a response with timeout
pub
async
fn
recv_response
(
&
self
,
timeout_duration
:
Duration
,
)
->
anyhow
::
Result
<
(
Bytes
,
Bytes
)
>
{
timeout
(
timeout_duration
,
self
.streams.response_stream
.recv_async
())
.await
.map_err
(|
_
|
anyhow
::
anyhow!
(
"Timeout waiting for response"
))
?
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Channel error: {}"
,
e
))
}
/// Receive an event with timeout
pub
async
fn
recv_event
(
&
self
,
timeout_duration
:
Duration
)
->
anyhow
::
Result
<
(
Bytes
,
Bytes
)
>
{
timeout
(
timeout_duration
,
self
.streams.event_stream
.recv_async
())
.await
.map_err
(|
_
|
anyhow
::
anyhow!
(
"Timeout waiting for event"
))
?
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Channel error: {}"
,
e
))
}
/// Collect multiple messages with timeout
pub
async
fn
collect_messages
(
&
self
,
count
:
usize
,
timeout_duration
:
Duration
,
)
->
anyhow
::
Result
<
Vec
<
(
Bytes
,
Bytes
)
>>
{
let
mut
messages
=
Vec
::
new
();
for
_
in
0
..
count
{
messages
.push
(
self
.recv_message
(
timeout_duration
)
.await
?
);
}
Ok
(
messages
)
}
/// Collect multiple messages with timeout, sorted by header for order-independent comparison
///
/// This is useful for testing transports that don't guarantee delivery order (e.g., HTTP).
/// Messages are sorted by header bytes to enable deterministic comparison regardless of
/// delivery order.
pub
async
fn
collect_messages_unordered
(
&
self
,
count
:
usize
,
timeout_duration
:
Duration
,
)
->
anyhow
::
Result
<
Vec
<
(
Bytes
,
Bytes
)
>>
{
let
mut
messages
=
self
.collect_messages
(
count
,
timeout_duration
)
.await
?
;
messages
.sort_by
(|
a
,
b
|
a
.0
.cmp
(
&
b
.0
));
Ok
(
messages
)
}
/// Collect multiple responses with timeout
pub
async
fn
collect_responses
(
&
self
,
count
:
usize
,
timeout_duration
:
Duration
,
)
->
anyhow
::
Result
<
Vec
<
(
Bytes
,
Bytes
)
>>
{
let
mut
responses
=
Vec
::
new
();
for
_
in
0
..
count
{
responses
.push
(
self
.recv_response
(
timeout_duration
)
.await
?
);
}
Ok
(
responses
)
}
/// Shutdown the transport
pub
fn
shutdown
(
self
)
{
self
.transport
.shutdown
();
}
}
// TCP-specific convenience constructors
impl TestTransportHandle<TcpTransport> {
    /// Create a new TCP transport on a random available port
    ///
    /// Binds a std listener to `127.0.0.1:0` (the OS picks the port) and hands it to
    /// `TcpTransportBuilder`. For other transport types, use `with_factory()`.
    pub async fn new_tcp() -> anyhow::Result<Self> {
        let bind_ephemeral = || -> anyhow::Result<TcpTransport> {
            let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
            TcpTransportBuilder::new().from_listener(listener)?.build()
        };
        Self::with_factory(bind_ephemeral).await
    }

    /// Alias for `new_tcp()` to maintain backward compatibility
    pub async fn new() -> anyhow::Result<Self> {
        Self::new_tcp().await
    }
}
// // UCX-specific convenience constructors
// #[cfg(feature = "ucx")]
// impl TestTransportHandle<UcxTransport> {
// /// Create a new UCX transport
// ///
// /// This is a convenience method for creating UCX transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_ucx() -> anyhow::Result<Self> {
// Self::with_factory(|| UcxTransportBuilder::new().build()).await
// }
// }
// // HTTP-specific convenience constructors
// #[cfg(feature = "http")]
// impl TestTransportHandle<HttpTransport> {
// /// Create a new HTTP transport with OS-provided port
// ///
// /// This is a convenience method for creating HTTP transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_http() -> anyhow::Result<Self> {
// Self::with_factory(|| {
// // Use default builder which binds to 0.0.0.0:0 (OS-provided port)
// HttpTransportBuilder::new().build()
// })
// .await
// }
// }
// // NATS-specific convenience constructor
// #[cfg(feature = "nats")]
// impl TestTransportHandle<NatsTransport> {
// /// Create a new NATS transport
// ///
// /// This is a convenience method for creating NATS transports.
// /// For other transport types, use `with_factory()`.
// ///
// /// Note: NATS transport requires special handling because it needs the instance_id
// /// at construction time to set up subject subscriptions. We can't use the generic
// /// with_factory() because it creates the instance_id AFTER calling the factory.
// pub async fn new_nats() -> anyhow::Result<Self> {
// // Create instance_id
// let instance_id = InstanceId::new_v4();
// let error_handler = Arc::new(TestErrorHandler::new());
// // Build transport
// let transport = NatsTransportBuilder::new()
// .nats_url("nats://127.0.0.1:4222")
// .build()?;
// // Create channels for this transport
// let (adapter, streams) = velo_transports::make_channels();
// // Get runtime handle
// let runtime = tokio::runtime::Handle::current();
// // Start the transport
// transport
// .start(instance_id, adapter, runtime.clone())
// .await?;
// // Give NATS a moment to establish subscriptions
// tokio::time::sleep(Duration::from_millis(50)).await;
// Ok(Self {
// transport,
// streams,
// instance_id,
// error_handler,
// runtime,
// })
// }
// }
// // gRPC-specific convenience constructors
// #[cfg(feature = "grpc")]
// impl TestTransportHandle<GrpcTransport> {
// /// Create a new gRPC transport with OS-provided port
// ///
// /// This is a convenience method for creating gRPC transports.
// /// For other transport types, use `with_factory()`.
// pub async fn new_grpc() -> anyhow::Result<Self> {
// Self::with_factory(|| {
// // Use default builder which binds to 0.0.0.0:0 (OS-provided port)
// GrpcTransportBuilder::new().build()
// })
// .await
// }
// }
/// Multi-transport test cluster
///
/// A generic cluster that works with any transport implementation.
/// All transports in the cluster are registered with each other in a full mesh topology.
pub struct TestCluster<T: Transport> {
    // Cluster members, in creation order; `get(index)` indexes into this vector.
    transports: Vec<TestTransportHandle<T>>,
}
impl<T: Transport> TestCluster<T> {
    /// Create a new test cluster using a factory function
    ///
    /// Generic constructor usable with any transport implementation. `factory` is called
    /// `size` times, one transport per call; afterwards every member is registered as a
    /// peer of every other member (full mesh).
    ///
    /// # Errors
    /// Returns the first error from creating a member or registering a peer.
    ///
    /// # Example
    /// ```ignore
    /// let cluster = TestCluster::with_factory(3, || {
    ///     MyTransportBuilder::new().build()
    /// }).await?;
    /// ```
    pub async fn with_factory<F>(size: usize, factory: F) -> anyhow::Result<Self>
    where
        F: Fn() -> anyhow::Result<T>,
    {
        let mut members: Vec<TestTransportHandle<T>> = Vec::new();
        while members.len() < size {
            let member = TestTransportHandle::with_factory(&factory).await?;
            members.push(member);
        }

        // Register all peers with each other (full mesh), skipping self-registration.
        for (src_idx, src) in members.iter().enumerate() {
            for (dst_idx, dst) in members.iter().enumerate() {
                if src_idx == dst_idx {
                    continue;
                }
                src.register_peer(dst)?;
            }
        }

        Ok(Self {
            transports: members,
        })
    }

    /// Get a transport by index
    ///
    /// Panics if `index` is out of range.
    pub fn get(&self, index: usize) -> &TestTransportHandle<T> {
        &self.transports[index]
    }

    /// Get all transports
    pub fn all(&self) -> &[TestTransportHandle<T>] {
        &self.transports
    }

    /// Shutdown all transports
    ///
    /// Consumes the cluster and shuts members down in creation order.
    pub fn shutdown(self) {
        self.transports
            .into_iter()
            .for_each(|member| member.shutdown());
    }
}
// TCP-specific convenience constructor
impl TestCluster<TcpTransport> {
    /// Create a new TCP test cluster with the specified number of transports
    ///
    /// Each member binds its own std listener on `127.0.0.1:0`, so the OS assigns the ports.
    /// For other transport types, use `with_factory()`.
    pub async fn new(size: usize) -> anyhow::Result<Self> {
        let bind_ephemeral = || -> anyhow::Result<TcpTransport> {
            let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
            TcpTransportBuilder::new().from_listener(listener)?.build()
        };
        Self::with_factory(size, bind_ephemeral).await
    }
}
// UCX-specific convenience constructor
// The UCX types are named by full path because the matching `use` at the top of this file
// is commented out; with bare `UcxTransport` / `UcxTransportBuilder` this block would not
// compile once the `ucx` feature is enabled.
// NOTE(review): assumes `velo_transports::ucx` exports both types, as the commented-out
// import suggests — confirm before enabling the feature.
#[cfg(feature = "ucx")]
impl TestCluster<velo_transports::ucx::UcxTransport> {
    /// Create a new UCX test cluster with the specified number of transports
    ///
    /// This is a convenience method for creating UCX clusters.
    /// For other transport types, use `with_factory()`.
    pub async fn new_ucx(size: usize) -> anyhow::Result<Self> {
        Self::with_factory(size, || {
            velo_transports::ucx::UcxTransportBuilder::new().build()
        })
        .await
    }
}
// // HTTP-specific convenience constructor
// #[cfg(feature = "http")]
// impl TestCluster<HttpTransport> {
// /// Create a new HTTP test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating HTTP clusters.
// /// For other transport types, use `with_factory()`.
// pub async fn new_http(size: usize) -> anyhow::Result<Self> {
// Self::with_factory(size, || {
// // Use default builder which binds to OS-provided ports
// HttpTransportBuilder::new().build()
// })
// .await
// }
// }
// // NATS-specific convenience constructor
// #[cfg(feature = "nats")]
// impl TestCluster<NatsTransport> {
// /// Create a new NATS test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating NATS clusters.
// /// For other transport types, use `with_factory()`.
// ///
// /// Note: NATS transport requires special handling because it needs the instance_id
// /// at construction time. We can't use the generic with_factory() which creates
// /// instance_id after calling the factory function.
// pub async fn new_nats(size: usize) -> anyhow::Result<Self> {
// let mut transports = Vec::new();
// for _ in 0..size {
// transports.push(TestTransportHandle::new_nats().await?);
// }
// // Register all peers with each other (full mesh)
// for i in 0..transports.len() {
// for j in 0..transports.len() {
// if i != j {
// transports[i].register_peer(&transports[j])?;
// }
// }
// }
// Ok(Self { transports })
// }
// }
// // gRPC-specific convenience constructor
// #[cfg(feature = "grpc")]
// impl TestCluster<GrpcTransport> {
// /// Create a new gRPC test cluster with the specified number of transports
// ///
// /// This is a convenience method for creating gRPC clusters.
// /// For other transport types, use `with_factory()`.
// pub async fn new_grpc(size: usize) -> anyhow::Result<Self> {
// Self::with_factory(size, || {
// // Use default builder which binds to OS-provided ports
// GrpcTransportBuilder::new().build()
// })
// .await
// }
// }
// Helper utilities
/// Get a random available port
///
/// Binds a throwaway listener to `127.0.0.1:0`, reads back the port the OS assigned, and
/// drops the listener on return. The port is therefore only known to have been free at the
/// time of the call; nothing reserves it afterwards.
///
/// Panics if binding or reading the local address fails.
pub fn get_random_port() -> u16 {
    let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
    let bound_addr = probe.local_addr().unwrap();
    bound_addr.port()
}
/// Create test data with the specified size
///
/// The bytes follow the repeating pattern `0, 1, ..., 255, 0, 1, ...`, i.e. byte `i` is
/// `i % 256`. `size == 0` yields an empty vector.
pub fn test_data(size: usize) -> Vec<u8> {
    (0..=u8::MAX).cycle().take(size).collect()
}
/// Create a test message with predictable content
///
/// Returns `(header, payload)` holding the UTF-8 bytes of `"header-<id>"` and
/// `"payload-<id>"` respectively.
pub fn test_message(id: u32) -> (Vec<u8>, Vec<u8>) {
    (
        String::into_bytes(format!("header-{}", id)),
        String::into_bytes(format!("payload-{}", id)),
    )
}
/// Assert that a received message matches expected values
///
/// `received` is a `(header, payload)` pair. Panics with `"Header mismatch"` or
/// `"Payload mismatch"` when the corresponding bytes differ; the header is checked first.
pub fn assert_message_eq(received: (Bytes, Bytes), expected_header: &[u8], expected_payload: &[u8]) {
    let (header, payload) = received;
    assert_eq!(&header[..], expected_header, "Header mismatch");
    assert_eq!(&payload[..], expected_payload, "Payload mismatch");
}
// Transport factory abstraction for parameterized tests
/// Transport factory trait for creating transports in parameterized tests
///
/// The scenario functions are generic over an implementation of this trait, so one
/// scenario can be run against any transport that provides a factory.
pub trait TransportFactory {
    /// The concrete transport type this factory produces.
    type Transport: Transport;

    /// Create a single test handle for this transport.
    async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>>;

    /// Create a test cluster with `size` members of this transport.
    async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>>;
}
/// TCP transport factory
pub
struct
TcpFactory
;
impl TransportFactory for TcpFactory {
    type Transport = TcpTransport;

    /// Create one TCP test handle via `TestTransportHandle::new_tcp`.
    async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
        TestTransportHandle::<TcpTransport>::new_tcp().await
    }

    /// Create a TCP cluster of `size` members via `TestCluster::new`.
    async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
        TestCluster::<TcpTransport>::new(size).await
    }
}
// /// UCX transport factory
// #[cfg(feature = "ucx")]
// pub struct UcxFactory;
// #[cfg(feature = "ucx")]
// impl TransportFactory for UcxFactory {
// type Transport = UcxTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_ucx().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_ucx(size).await
// }
// }
// /// HTTP transport factory
// #[cfg(feature = "http")]
// pub struct HttpFactory;
// #[cfg(feature = "http")]
// impl TransportFactory for HttpFactory {
// type Transport = HttpTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_http().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_http(size).await
// }
// }
// /// NATS transport factory
// #[cfg(feature = "nats")]
// pub struct NatsFactory;
// #[cfg(feature = "nats")]
// impl TransportFactory for NatsFactory {
// type Transport = NatsTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_nats().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_nats(size).await
// }
// }
// /// gRPC transport factory
// #[cfg(feature = "grpc")]
// pub struct GrpcFactory;
// #[cfg(feature = "grpc")]
// impl TransportFactory for GrpcFactory {
// type Transport = GrpcTransport;
// async fn create() -> anyhow::Result<TestTransportHandle<Self::Transport>> {
// TestTransportHandle::new_grpc().await
// }
// async fn create_cluster(size: usize) -> anyhow::Result<TestCluster<Self::Transport>> {
// TestCluster::new_grpc(size).await
// }
// }
lib/velo-transports/tests/common/scenarios.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Generic test scenarios that work with any transport implementation
use
super
::
*
;
use
std
::
time
::
Duration
;
/// Upper bound (5 seconds) the scenarios wait for each expected receive.
const TEST_TIMEOUT: Duration = Duration::from_secs(5);
/// A sends one `Message` to B; B must receive the identical header and payload.
pub async fn single_message_round_trip<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A and B register each other and each sends one `Message`; both must arrive intact.
pub async fn bidirectional_messaging<F: TransportFactory>() {
    let node_a = F::create().await.unwrap();
    let node_b = F::create().await.unwrap();
    node_a.register_peer(&node_b).unwrap();
    node_b.register_peer(&node_a).unwrap();

    // A -> B
    let (a_to_b_header, a_to_b_payload) = test_message(1);
    node_a.send(
        node_b.instance_id,
        a_to_b_header.clone(),
        a_to_b_payload.clone(),
        MessageType::Message,
    );

    // B -> A
    let (b_to_a_header, b_to_a_payload) = test_message(2);
    node_b.send(
        node_a.instance_id,
        b_to_a_header.clone(),
        b_to_a_payload.clone(),
        MessageType::Message,
    );

    // Both sends are issued before either side starts receiving.
    let seen_by_b = node_b.recv_message(TEST_TIMEOUT).await.unwrap();
    let seen_by_a = node_a.recv_message(TEST_TIMEOUT).await.unwrap();

    assert_message_eq(seen_by_b, &a_to_b_header, &a_to_b_payload);
    assert_message_eq(seen_by_a, &b_to_a_header, &b_to_a_payload);

    node_a.shutdown();
    node_b.shutdown();
}
/// A sends ten `Message`s to B back to back; B must receive all ten with matching content.
///
/// The comparison ignores delivery order: both the received and the expected messages are
/// sorted by header before being compared pairwise.
pub async fn multiple_messages_same_connection<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    // Send 10 messages.
    for id in 0..10 {
        let (header, payload) = test_message(id);
        sender.send(receiver.instance_id, header, payload, MessageType::Message);
    }

    // Receive all messages, sorted by header.
    let delivered = receiver
        .collect_messages_unordered(10, TEST_TIMEOUT)
        .await
        .unwrap();

    // Build the expected set and sort it the same way.
    let mut expected: Vec<_> = (0..10).map(test_message).collect();
    expected.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

    // `collect_messages_unordered(10, ..)` returns exactly ten messages on success, so the
    // two sequences have equal length here.
    for (actual, wanted) in delivered.into_iter().zip(&expected) {
        assert_message_eq(actual, &wanted.0, &wanted.1);
    }

    sender.shutdown();
    receiver.shutdown();
}
/// A `Response` sent from A to B must arrive on B's response stream.
pub async fn response_message_type<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Response,
    );

    let delivered = receiver.recv_response(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// An `Event` sent from A to B must arrive on B's event stream.
pub async fn event_message_type<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Event,
    );

    let delivered = receiver.recv_event(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// An `Ack` sent from A to B is expected on B's event stream.
pub async fn ack_message_type<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Ack,
    );

    // Acks route to the event stream, so that is where this scenario waits for it.
    let delivered = receiver.recv_event(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A sends a `Message`, a `Response` and an `Event` to B; each must arrive on its own stream.
pub async fn mixed_message_types<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let destination = receiver.instance_id;

    // Send one frame of each type, all before receiving anything.
    let (msg_h, msg_p) = test_message(1);
    sender.send(
        destination,
        msg_h.clone(),
        msg_p.clone(),
        MessageType::Message,
    );

    let (resp_h, resp_p) = test_message(2);
    sender.send(
        destination,
        resp_h.clone(),
        resp_p.clone(),
        MessageType::Response,
    );

    let (event_h, event_p) = test_message(3);
    sender.send(
        destination,
        event_h.clone(),
        event_p.clone(),
        MessageType::Event,
    );

    // Receive from the stream that matches each type.
    let got_message = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    let got_response = receiver.recv_response(TEST_TIMEOUT).await.unwrap();
    let got_event = receiver.recv_event(TEST_TIMEOUT).await.unwrap();

    assert_message_eq(got_message, &msg_h, &msg_p);
    assert_message_eq(got_response, &resp_h, &resp_p);
    assert_message_eq(got_event, &event_h, &event_p);

    sender.shutdown();
    receiver.shutdown();
}
/// A 1 MiB payload sent from A to B must arrive byte-for-byte intact.
pub async fn large_payload<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    // 1MB payload
    let header = b"large-payload".to_vec();
    let payload = test_data(1024 * 1024);

    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.shutdown();
    receiver.shutdown();
}
/// A message with an empty header and an empty payload must still be delivered (as empty).
pub async fn empty_header_and_payload<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let destination = receiver.instance_id;
    sender.send(destination, vec![], vec![], MessageType::Message);

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &[], &[]);

    sender.shutdown();
    receiver.shutdown();
}
/// In a three-node cluster every node sends one `Message` to each other node; every node
/// must then receive exactly two messages.
pub async fn cluster_mesh_communication<F: TransportFactory>() {
    let cluster = F::create_cluster(3).await.unwrap();

    // Each node sends to every other node; the message id encodes (source, destination).
    for src in 0..3 {
        for dst in (0..3).filter(|&dst| dst != src) {
            let (header, payload) = test_message((src * 10 + dst) as u32);
            let destination = cluster.get(dst).instance_id;
            cluster
                .get(src)
                .send(destination, header, payload, MessageType::Message);
        }
    }

    // Each node should receive 2 messages (one from each of the other two nodes).
    for node_idx in 0..3 {
        let inbox = cluster
            .get(node_idx)
            .collect_messages(2, TEST_TIMEOUT)
            .await
            .unwrap();
        assert_eq!(inbox.len(), 2);
    }

    cluster.shutdown();
}
/// A issues ten `send` calls to B back to back, without receiving in between; B must then
/// receive all ten messages.
///
/// `send` is a plain (non-async) call on the handle, so the sends are issued from this one
/// task. Earlier versions of this scenario also spawned ten Tokio tasks that only slept for
/// a microsecond; they never touched either transport and asserted nothing, so they were
/// removed rather than left to suggest that concurrent sending is covered here.
///
/// Only the number of delivered messages is checked, not their content or order.
pub async fn concurrent_senders<F: TransportFactory>() {
    let transport_a = F::create().await.unwrap();
    let transport_b = F::create().await.unwrap();
    transport_a.register_peer(&transport_b).unwrap();

    // Queue all ten messages before B starts receiving.
    let target_id = transport_b.instance_id;
    for i in 0..10 {
        let (header, payload) = test_message(i);
        transport_a.send(target_id, header, payload, MessageType::Message);
    }

    // Receive all messages
    let messages = transport_b
        .collect_messages(10, TEST_TIMEOUT)
        .await
        .unwrap();
    assert_eq!(messages.len(), 10);

    transport_a.shutdown();
    transport_b.shutdown();
}
/// Sending to a peer that was never registered must report exactly one error to the
/// sender's error handler, carrying the original header and payload and an error text that
/// contains "peer not registered" (case-insensitive).
pub async fn send_to_unregistered_peer<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let stranger = F::create().await.unwrap();

    // Deliberately no `register_peer` call: the sender does not know the target.
    let (header, payload) = test_message(1);
    sender.send(
        stranger.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    // Give it a moment to process
    let grace = Duration::from_millis(100);
    tokio::time::sleep(grace).await;

    // Should have an error
    assert!(sender.error_handler.error_count() > 0);

    let errors = sender.error_handler.get_errors();
    assert_eq!(errors.len(), 1);

    let (failed_header, failed_payload, reason) = &errors[0];
    assert_eq!(*failed_header, header.as_slice());
    assert_eq!(*failed_payload, payload.as_slice());
    assert!(reason.to_lowercase().contains("peer not registered"));

    sender.shutdown();
    stranger.shutdown();
}
/// Two messages sent one after the other (each received before the next is sent) must both
/// arrive, and the sender's error handler must have recorded no errors.
pub async fn connection_reuse<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let destination = receiver.instance_id;

    // First message establishes connection
    let (first_header, first_payload) = test_message(1);
    sender.send(
        destination,
        first_header.clone(),
        first_payload.clone(),
        MessageType::Message,
    );
    let first_delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(first_delivered, &first_header, &first_payload);

    // Second message reuses connection
    let (second_header, second_payload) = test_message(2);
    sender.send(
        destination,
        second_header.clone(),
        second_payload.clone(),
        MessageType::Message,
    );
    let second_delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(second_delivered, &second_header, &second_payload);

    // No errors should have occurred
    assert_eq!(sender.error_handler.error_count(), 0);

    sender.shutdown();
    receiver.shutdown();
}
/// After one successful message exchange, shutting down both transports must not panic.
pub async fn graceful_shutdown<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    // Send a message
    let (header, payload) = test_message(1);
    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    // Receive it
    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    // Shutdown should complete without panics
    sender.shutdown();
    receiver.shutdown();
}
/// A sends 100 `Message`s to B back to back; B must receive all of them with matching
/// content.
///
/// The comparison ignores delivery order: both the received and the expected messages are
/// sorted by header before being compared pairwise.
pub async fn high_throughput<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    let num_messages = 100;
    let destination = receiver.instance_id;

    // Send many messages
    for id in 0..num_messages {
        let (header, payload) = test_message(id);
        sender.send(destination, header, payload, MessageType::Message);
    }

    // Receive all messages, sorted by header.
    let wanted_count = num_messages as usize;
    let delivered = receiver
        .collect_messages_unordered(wanted_count, TEST_TIMEOUT)
        .await
        .unwrap();
    assert_eq!(delivered.len(), wanted_count);

    // Build the expected set and sort it the same way.
    let mut expected: Vec<_> = (0..num_messages).map(test_message).collect();
    expected.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

    // Verify all messages received correctly (lengths were asserted equal above).
    for (actual, wanted) in delivered.into_iter().zip(&expected) {
        assert_message_eq(actual, &wanted.0, &wanted.1);
    }

    sender.shutdown();
    receiver.shutdown();
}
/// A 512 KiB payload sent from A to B must arrive intact and without any error being
/// reported to the sender's error handler.
///
/// Only correctness of delivery is checked; nothing here measures copies.
pub async fn zero_copy_efficiency<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let receiver = F::create().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    // Large payload to test zero-copy
    let header = b"zero-copy-test".to_vec();
    let payload = test_data(512 * 1024); // 512KB

    let destination = receiver.instance_id;
    sender.send(
        destination,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );

    let delivered = receiver.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    // Verify no errors
    assert_eq!(sender.error_handler.error_count(), 0);

    sender.shutdown();
    receiver.shutdown();
}
// --- Drain / shutdown scenarios ---
/// After begin_drain on B, messages sent from A to B should NOT arrive on B's message_stream.
///
/// The scenario passes when a 500 ms wait on B's message stream times out.
pub async fn drain_rejects_messages<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    sender.register_peer(&draining).unwrap();

    // Begin drain on B (both transport-level and shutdown-state, mirroring VeloBackend::graceful_shutdown)
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    let settle = Duration::from_millis(100);
    tokio::time::sleep(settle).await;

    // A sends a Message to B
    let (header, payload) = test_message(1);
    sender.send(draining.instance_id, header, payload, MessageType::Message);

    // B's message_stream should be empty (message rejected during drain)
    let wait_for_message = draining.streams.message_stream.recv_async();
    let result = tokio::time::timeout(Duration::from_millis(500), wait_for_message).await;
    assert!(
        result.is_err(),
        "Expected timeout — messages should be rejected during drain"
    );

    // Tear down: the sender first, then cancel B's teardown token before shutting B down.
    sender.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
/// Checks that a `Response` sent to a draining transport is still delivered.
///
/// B enters drain on both its transport and its shutdown state, then A sends
/// B one `Response`. B must receive it via `recv_response` with matching
/// header and payload.
pub async fn drain_accepts_responses<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    sender.register_peer(&draining).unwrap();

    // Put the receiving side into drain before anything is sent.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    let settle = Duration::from_millis(100);
    tokio::time::sleep(settle).await;

    // Sender -> draining peer: a Response.
    let (header, payload) = test_message(1);
    sender.send(
        draining.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Response,
    );

    // The response must still come through despite the drain.
    let delivered = draining.recv_response(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
/// Checks that an `Event` sent to a draining transport is still delivered.
///
/// B enters drain on both its transport and its shutdown state, then A sends
/// B one `Event`. B must receive it via `recv_event` with matching header and
/// payload.
pub async fn drain_accepts_events<F: TransportFactory>() {
    let sender = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    sender.register_peer(&draining).unwrap();

    // Put the receiving side into drain before anything is sent.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    let settle = Duration::from_millis(100);
    tokio::time::sleep(settle).await;

    // Sender -> draining peer: an Event.
    let (header, payload) = test_message(1);
    sender.send(
        draining.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Event,
    );

    // The event must still come through despite the drain.
    let delivered = draining.recv_event(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    sender.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
/// Checks that A's health check of B still succeeds while B is draining.
///
/// A first sends B one `Message` and B receives it, so a connection exists.
/// B then enters drain, and A's `check_health` against B (2 s limit) must
/// return `Ok`.
pub async fn health_during_drain<F: TransportFactory>() {
    let prober = F::create().await.unwrap();
    let draining = F::create().await.unwrap();
    prober.register_peer(&draining).unwrap();

    // Warm-up round trip so the connection is established before the drain.
    let (header, payload) = test_message(1);
    prober.send(
        draining.instance_id,
        header.clone(),
        payload.clone(),
        MessageType::Message,
    );
    let delivered = draining.recv_message(TEST_TIMEOUT).await.unwrap();
    assert_message_eq(delivered, &header, &payload);

    // Put the peer into drain.
    draining.transport.begin_drain();
    draining.streams.shutdown_state.begin_drain();
    let settle = Duration::from_millis(100);
    tokio::time::sleep(settle).await;

    // The health check must keep working while the peer drains.
    let health_deadline = Duration::from_secs(2);
    let result = prober
        .transport
        .check_health(draining.instance_id, health_deadline)
        .await;
    assert!(
        result.is_ok(),
        "Health check should succeed during drain: {:?}",
        result.err()
    );

    prober.transport.shutdown();
    draining.streams.shutdown_state.teardown_token().cancel();
    draining.transport.shutdown();
}
lib/velo-transports/tests/tcp_integration.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for TCP transport
mod
common
;
use
common
::{
TcpFactory
,
scenarios
};
/// Runs the shared `single_message_round_trip` scenario with `TcpFactory`.
#[tokio::test]
async fn test_single_message_round_trip() {
    let scenario = scenarios::single_message_round_trip::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `bidirectional_messaging` scenario with `TcpFactory`.
#[tokio::test]
async fn test_bidirectional_messaging() {
    let scenario = scenarios::bidirectional_messaging::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `multiple_messages_same_connection` scenario with `TcpFactory`.
#[tokio::test]
async fn test_multiple_messages_same_connection() {
    let scenario = scenarios::multiple_messages_same_connection::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `response_message_type` scenario with `TcpFactory`.
#[tokio::test]
async fn test_response_message_type() {
    let scenario = scenarios::response_message_type::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `event_message_type` scenario with `TcpFactory`.
#[tokio::test]
async fn test_event_message_type() {
    let scenario = scenarios::event_message_type::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `ack_message_type` scenario with `TcpFactory`.
#[tokio::test]
async fn test_ack_message_type() {
    let scenario = scenarios::ack_message_type::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `mixed_message_types` scenario with `TcpFactory`.
#[tokio::test]
async fn test_mixed_message_types() {
    let scenario = scenarios::mixed_message_types::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `large_payload` scenario with `TcpFactory`.
#[tokio::test]
async fn test_large_payload() {
    let scenario = scenarios::large_payload::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `empty_header_and_payload` scenario with `TcpFactory`.
#[tokio::test]
async fn test_empty_header_and_payload() {
    let scenario = scenarios::empty_header_and_payload::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `cluster_mesh_communication` scenario with `TcpFactory`.
#[tokio::test]
async fn test_cluster_mesh_communication() {
    let scenario = scenarios::cluster_mesh_communication::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `concurrent_senders` scenario with `TcpFactory`.
#[tokio::test]
async fn test_concurrent_senders() {
    let scenario = scenarios::concurrent_senders::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `send_to_unregistered_peer` scenario with `TcpFactory`.
#[tokio::test]
async fn test_send_to_unregistered_peer() {
    let scenario = scenarios::send_to_unregistered_peer::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `connection_reuse` scenario with `TcpFactory`.
#[tokio::test]
async fn test_connection_reuse() {
    let scenario = scenarios::connection_reuse::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `graceful_shutdown` scenario with `TcpFactory`.
#[tokio::test]
async fn test_graceful_shutdown() {
    let scenario = scenarios::graceful_shutdown::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `high_throughput` scenario with `TcpFactory`.
#[tokio::test]
async fn test_high_throughput() {
    let scenario = scenarios::high_throughput::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `zero_copy_efficiency` scenario with `TcpFactory`.
#[tokio::test]
async fn test_zero_copy_efficiency() {
    let scenario = scenarios::zero_copy_efficiency::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `drain_rejects_messages` scenario with `TcpFactory`.
#[tokio::test]
async fn test_drain_rejects_messages() {
    let scenario = scenarios::drain_rejects_messages::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `drain_accepts_responses` scenario with `TcpFactory`.
#[tokio::test]
async fn test_drain_accepts_responses() {
    let scenario = scenarios::drain_accepts_responses::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `drain_accepts_events` scenario with `TcpFactory`.
#[tokio::test]
async fn test_drain_accepts_events() {
    let scenario = scenarios::drain_accepts_events::<TcpFactory>();
    scenario.await;
}
/// Runs the shared `health_during_drain` scenario with `TcpFactory`.
#[tokio::test]
async fn test_health_during_drain() {
    let scenario = scenarios::health_during_drain::<TcpFactory>();
    scenario.await;
}
lib/velo-transports/tests/tcp_shutdown.rs
0 → 100644
View file @
63d7c01c
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for TCP graceful shutdown
//!
//! These tests verify the 3-phase shutdown behavior:
//! 1. Gate: new Message frames are rejected with ShuttingDown
//! 2. Drain: in-flight work completes, responses/events still flow
//! 3. Teardown: listener and writer tasks exit
mod
common
;
use
bytes
::
Bytes
;
use
std
::
time
::
Duration
;
use
tokio
::
time
::{
sleep
,
timeout
};
use
velo_transports
::
tcp
::
TcpFrameCodec
;
use
velo_transports
::{
MessageType
,
Transport
};
use
common
::
TestTransportHandle
;
/// Helper: open a raw TCP connection to `addr` and write a single frame to it.
///
/// The frame is encoded with `TcpFrameCodec::encode_frame` from `msg_type`,
/// `header` and `payload`. The connected stream is returned so the caller can
/// read a reply from it. Panics if connecting or encoding fails.
async fn connect_and_send_frame(
    addr: std::net::SocketAddr,
    msg_type: MessageType,
    header: &[u8],
    payload: &[u8],
) -> tokio::net::TcpStream {
    use tokio::net::TcpStream;

    let mut conn = TcpStream::connect(addr).await.unwrap();
    let written = TcpFrameCodec::encode_frame(&mut conn, msg_type, header, payload).await;
    written.unwrap();
    conn
}
/// Helper: read one frame from a raw TCP stream.
async
fn
read_one_frame
(
stream
:
&
mut
tokio
::
net
::
TcpStream
)
->
(
MessageType
,
Bytes
,
Bytes
)
{
use
futures
::
StreamExt
;
use
tokio_util
::
codec
::
Framed
;
let
mut
framed
=
Framed
::
new
(
stream
,
TcpFrameCodec
::
new
());
framed
.next
()
.await
.unwrap
()
.unwrap
()
}
/// Extract the socket address a `TcpTransport` test handle is reachable at.
///
/// Looks up the entry stored under the transport's own key in its address,
/// interprets the bytes as UTF-8, drops a leading `tcp://` if present, and
/// parses the remainder as a `SocketAddr`. Panics if any step fails.
fn get_bind_addr(
    handle: &TestTransportHandle<velo_transports::tcp::TcpTransport>,
) -> std::net::SocketAddr {
    let transport = &handle.transport;
    let addr = transport.address();
    let key = transport.key();
    let endpoint = addr.get_entry(&key).unwrap().unwrap();

    let text = std::str::from_utf8(&endpoint).unwrap();
    // The scheme prefix is optional; without it the text is parsed as-is.
    let host_port = match text.strip_prefix("tcp://") {
        Some(rest) => rest,
        None => text,
    };
    host_port.parse().unwrap()
}
// --- Test 18: Drain rejects Message frames ---
/// A `Message` frame sent to a draining transport is answered with a
/// `ShuttingDown` frame that echoes the request header and has no payload.
#[tokio::test]
async fn test_tcp_drain_rejects_messages() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Enter drain, then give the listener a moment before connecting.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Raw client sends a Message frame.
    let request_header = b"req-header";
    let mut stream =
        connect_and_send_frame(addr, MessageType::Message, request_header, b"req-payload").await;

    // The reply must be ShuttingDown, with the original header echoed back
    // and an empty payload.
    let (msg_type, header, payload) = read_one_frame(&mut stream).await;
    assert_eq!(msg_type, MessageType::ShuttingDown);
    assert_eq!(&header[..], request_header);
    assert_eq!(payload.len(), 0);

    shutdown.teardown_token().cancel();
}
// --- Test 19: Drain accepts Response frames ---
/// A `Response` frame sent to a draining transport still arrives on its
/// `response_stream` with header and payload intact.
#[tokio::test]
async fn test_tcp_drain_accepts_responses() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Enter drain, then give the listener a moment before connecting.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Raw client sends a Response frame; the client stream is dropped right away.
    connect_and_send_frame(addr, MessageType::Response, b"resp-header", b"resp-payload").await;

    // It must be delivered on the response stream within 2 s.
    let pending = handle.streams.response_stream.recv_async();
    let (header, payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&header[..], b"resp-header");
    assert_eq!(&payload[..], b"resp-payload");

    shutdown.teardown_token().cancel();
}
// --- Test 20: Drain accepts Event frames ---
/// An `Event` frame sent to a draining transport still arrives on its
/// `event_stream` with header and payload intact.
#[tokio::test]
async fn test_tcp_drain_accepts_events() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Enter drain, then give the listener a moment before connecting.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Raw client sends an Event frame; the client stream is dropped right away.
    connect_and_send_frame(addr, MessageType::Event, b"evt-header", b"evt-payload").await;

    // It must be delivered on the event stream within 2 s.
    let pending = handle.streams.event_stream.recv_async();
    let (header, payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&header[..], b"evt-header");
    assert_eq!(&payload[..], b"evt-payload");

    shutdown.teardown_token().cancel();
}
// --- Test 21: New connection during drain still accepts responses ---
/// A connection opened only after drain has begun can still deliver a
/// `Response` frame to the transport's `response_stream`.
#[tokio::test]
async fn test_tcp_new_connection_during_drain() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;

    // Drain starts BEFORE any client connects.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // First connection is made after the drain began.
    connect_and_send_frame(addr, MessageType::Response, b"new-resp", b"new-payload").await;

    // The response must be delivered within 2 s.
    let pending = handle.streams.response_stream.recv_async();
    let (header, payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&header[..], b"new-resp");
    assert_eq!(&payload[..], b"new-payload");

    shutdown.teardown_token().cancel();
}
// --- Test 22: ShuttingDown frame roundtrip ---
/// Encodes a `ShuttingDown` frame with a header and an empty payload, decodes
/// it with a fresh `TcpFrameCodec`, and checks the type, header and payload
/// length survive the round trip.
#[test]
fn test_shutting_down_frame_roundtrip() {
    use bytes::BytesMut;
    use tokio_util::codec::Decoder;

    let header = b"correlation-header";
    let payload = b"";

    // Encode into an in-memory buffer.
    let mut wire = Vec::new();
    let encoded =
        TcpFrameCodec::encode_frame_sync(&mut wire, MessageType::ShuttingDown, header, payload);
    encoded.unwrap();

    // Decode the same bytes back.
    let mut incoming = BytesMut::from(&wire[..]);
    let mut codec = TcpFrameCodec::new();
    let frame = codec.decode(&mut incoming).unwrap().unwrap();
    let (msg_type, decoded_header, decoded_payload) = frame;

    assert_eq!(msg_type, MessageType::ShuttingDown);
    assert_eq!(&decoded_header[..], header);
    assert_eq!(decoded_payload.len(), 0);
}
// --- Test 23: Full graceful shutdown lifecycle ---
/// Walks one transport through the whole shutdown sequence.
///
/// Steps: a `Message` is delivered normally; an in-flight guard is acquired;
/// drain begins; a new `Message` is answered with `ShuttingDown` while a
/// `Response` is still delivered; a background task waiting on
/// `wait_for_drain` stays pending while the guard is held; dropping the guard
/// lets that task finish and cancel the teardown token.
#[tokio::test]
async fn test_tcp_graceful_shutdown_lifecycle() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let addr = get_bind_addr(&handle);
    let shutdown = &handle.streams.shutdown_state;
    let recv_deadline = Duration::from_secs(2);

    // Normal operation: a Message goes in and comes out of the message stream.
    connect_and_send_frame(addr, MessageType::Message, b"normal-msg", b"normal-pay").await;
    let (header, _payload) = timeout(recv_deadline, handle.streams.message_stream.recv_async())
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&header[..], b"normal-msg");

    // Hold an in-flight guard to simulate a request still being processed.
    let guard = shutdown.acquire();
    assert_eq!(shutdown.in_flight_count(), 1);

    // Phase 1: begin drain.
    shutdown.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // New Message frames are now refused with ShuttingDown.
    let mut stream = connect_and_send_frame(addr, MessageType::Message, b"reject-me", b"").await;
    let (msg_type, _, _) = read_one_frame(&mut stream).await;
    assert_eq!(msg_type, MessageType::ShuttingDown);

    // Responses keep flowing.
    connect_and_send_frame(addr, MessageType::Response, b"still-ok", b"data").await;
    let (header, _) = timeout(recv_deadline, handle.streams.response_stream.recv_async())
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&header[..], b"still-ok");

    // Background shutdown: blocks in wait_for_drain while the guard is held.
    let shutdown_state = handle.streams.shutdown_state.clone();
    let shutdown_handle = tokio::spawn(async move {
        // Phase 2: wait for in-flight work to drain.
        shutdown_state.wait_for_drain().await;
        // Phase 3: teardown.
        shutdown_state.teardown_token().cancel();
    });

    // With the guard still alive the task must not have completed.
    sleep(Duration::from_millis(100)).await;
    assert!(!shutdown_handle.is_finished());

    // Releasing the guard lets drain finish and teardown fire.
    drop(guard);
    timeout(Duration::from_secs(2), shutdown_handle)
        .await
        .expect("shutdown should complete")
        .unwrap();

    assert!(shutdown.teardown_token().is_cancelled());
}
// --- Test 24: Shutdown timeout forces teardown ---
/// When an in-flight guard is never released, a shutdown task that bounds
/// `wait_for_drain` with a 100 ms timeout still cancels the teardown token,
/// and the in-flight count remains 1.
#[tokio::test]
async fn test_tcp_shutdown_timeout_forces_teardown() {
    let handle = TestTransportHandle::new_tcp().await.unwrap();
    let shutdown = &handle.streams.shutdown_state;

    // Guard is held for the entire test, so drain can never complete.
    let _guard = shutdown.acquire();

    let shutdown_state = handle.streams.shutdown_state.clone();
    let shutdown_handle = tokio::spawn(async move {
        shutdown_state.begin_drain();

        // Phase 2: bounded wait; the timeout result is deliberately ignored.
        let drain_limit = Duration::from_millis(100);
        let _ = tokio::time::timeout(drain_limit, shutdown_state.wait_for_drain()).await;

        // Phase 3: forced teardown while the guard is still held.
        shutdown_state.teardown_token().cancel();
    });

    timeout(Duration::from_secs(2), shutdown_handle)
        .await
        .expect("shutdown should complete via timeout")
        .unwrap();

    // Teardown fired even though the guard was never dropped.
    assert!(shutdown.teardown_token().is_cancelled());

    // The guard is still counted as in flight.
    assert_eq!(shutdown.in_flight_count(), 1);
}
// --- Test 25: Outbound sends during drain ---
/// A transport whose shutdown state is draining can still send a `Response`
/// to a peer, and the peer receives it on its `response_stream`.
#[tokio::test]
async fn test_outbound_sends_during_drain() {
    // Two transports, registered with each other in both directions.
    let draining = TestTransportHandle::new_tcp().await.unwrap();
    let peer = TestTransportHandle::new_tcp().await.unwrap();
    draining.register_peer(&peer).unwrap();
    peer.register_peer(&draining).unwrap();

    // Only the sending side enters drain.
    draining.streams.shutdown_state.begin_drain();
    sleep(Duration::from_millis(50)).await;

    // Outbound Response from the draining transport.
    let response_header = b"response-hdr".to_vec();
    let response_payload = b"response-pay".to_vec();
    draining.send(
        peer.instance_id,
        response_header,
        response_payload,
        MessageType::Response,
    );

    // The peer must receive it within 2 s.
    let pending = peer.streams.response_stream.recv_async();
    let (header, payload) = timeout(Duration::from_secs(2), pending)
        .await
        .expect("timeout")
        .expect("recv");
    assert_eq!(&header[..], b"response-hdr");
    assert_eq!(&payload[..], b"response-pay");

    draining.streams.shutdown_state.teardown_token().cancel();
    peer.streams.shutdown_state.teardown_token().cancel();
}
// --- Test 26: Connection writer exits on teardown ---
/// After `shutdown()` on the sending transport, a further `send` must not
/// panic.
///
/// One message is delivered first, then the sender is shut down and a second
/// send is attempted. Nothing is asserted about the second message; the test
/// passes as long as no panic occurs.
#[tokio::test]
async fn test_connection_writer_exits_on_teardown() {
    let sender = TestTransportHandle::new_tcp().await.unwrap();
    let receiver = TestTransportHandle::new_tcp().await.unwrap();
    sender.register_peer(&receiver).unwrap();

    // First send establishes the connection writer task.
    sender.send(
        receiver.instance_id,
        b"setup".to_vec(),
        b"data".to_vec(),
        MessageType::Message,
    );

    // Wait until the receiver has actually got it.
    let first_delivery = receiver.streams.message_stream.recv_async();
    timeout(Duration::from_secs(2), first_delivery)
        .await
        .expect("timeout")
        .expect("recv");

    // Shut the sender down and give its writer tasks time to exit.
    sender.transport.shutdown();
    sleep(Duration::from_millis(200)).await;

    // This send is expected to fail (error handler gets invoked, not a panic).
    sender.send(
        receiver.instance_id,
        b"should-fail".to_vec(),
        b"data".to_vec(),
        MessageType::Message,
    );

    // Leave room for the asynchronous error path to run.
    sleep(Duration::from_millis(100)).await;

    // The message either reached the error handler or was silently dropped
    // (connection cleared during shutdown); reaching this point without a
    // panic is the whole check.
    receiver.streams.shutdown_state.teardown_token().cancel();
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment