"...controller/dynamocomponentdeployment_controller_test.go" did not exist on "c939da0c19dc10c12e38515c67bc8de7f77bb0f3"
Unverified Commit 6633be54 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix(runtime): lock-free TCP connection pool for OOM prevention and...

fix(runtime): lock-free TCP connection pool for OOM prevention and inflight/resource-leak fixes (#7806)
parent 7d46806d
......@@ -2580,6 +2580,7 @@ dependencies = [
"chrono",
"console-subscriber",
"criterion 0.5.1",
"crossbeam-queue",
"cudarc",
"dashmap",
"derive-getters",
......
......@@ -1697,6 +1697,7 @@ dependencies = [
"blake3",
"bytes",
"chrono",
"crossbeam-queue",
"dashmap",
"derive-getters",
"derive_builder",
......
......@@ -1743,6 +1743,7 @@ dependencies = [
"blake3",
"bytes",
"chrono",
"crossbeam-queue",
"cudarc",
"dashmap",
"derive-getters",
......
......@@ -90,46 +90,11 @@ struct RequestGuard {
expected_output_tokens: Option<u32>,
/// Deferred session close action (fires after generation completes)
deferred_close: Option<SessionCloseAction>,
}
struct PendingDispatchGuard {
chooser: Arc<KvRouter>,
scheduler_tracked: bool,
context_id: String,
deferred_close: Option<SessionCloseAction>,
disarmed: bool,
}
fn spawn_cleanup_task(
chooser: &Arc<KvRouter>,
scheduler_tracked: bool,
context_id: &str,
deferred_close: Option<SessionCloseAction>,
log_context: &'static str,
) {
if deferred_close.is_none() && !scheduler_tracked {
return;
}
let Ok(handle) = tokio::runtime::Handle::try_current() else {
tracing::warn!(
"No tokio runtime for {log_context} cleanup of request {}",
context_id
);
return;
};
let chooser = chooser.clone();
let context_id = context_id.to_owned();
handle.spawn(async move {
if scheduler_tracked && let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id} ({log_context}): {e}");
}
if let Some(close) = deferred_close {
close.execute(&context_id);
}
});
/// True once inner.direct() has returned Ok — guards record_metrics() so
/// that a dispatch failure does not emit metrics for a request that never
/// reached the backend (spurious requests_total increment, OSL histogram
/// zeros, premature tracker.record_finish()).
dispatched: bool,
}
impl RequestGuard {
......@@ -223,7 +188,10 @@ impl RequestGuard {
}
fn record_metrics(&mut self) {
if self.metrics_recorded {
// Skip metrics for requests that never reached the backend (dispatch
// failure before direct() returned Ok). Recording here would emit
// spurious requests_total increments and OSL-histogram zeros.
if self.metrics_recorded || !self.dispatched {
return;
}
self.metrics_recorded = true;
......@@ -254,51 +222,34 @@ impl Drop for RequestGuard {
fn drop(&mut self) {
self.record_metrics();
spawn_cleanup_task(
&self.chooser,
!self.freed && self.scheduler_tracked,
&self.context_id,
self.deferred_close.take(),
"drop guard",
);
}
}
let deferred_close = self.deferred_close.take();
let needs_free = !self.freed && self.scheduler_tracked;
impl PendingDispatchGuard {
fn new(
chooser: Arc<KvRouter>,
scheduler_tracked: bool,
context_id: String,
deferred_close: Option<SessionCloseAction>,
) -> Self {
Self {
chooser,
scheduler_tracked,
context_id,
deferred_close,
disarmed: false,
if deferred_close.is_none() && !needs_free {
return;
}
}
fn disarm(mut self) -> Option<SessionCloseAction> {
self.disarmed = true;
self.deferred_close.take()
}
}
impl Drop for PendingDispatchGuard {
fn drop(&mut self) {
if self.disarmed {
let Ok(handle) = tokio::runtime::Handle::try_current() else {
tracing::warn!(
"No tokio runtime for drop guard cleanup of request {}",
self.context_id
);
return;
}
};
spawn_cleanup_task(
&self.chooser,
self.scheduler_tracked,
&self.context_id,
self.deferred_close.take(),
"dispatch guard",
);
// Mirror finish(): free the scheduler slot first, then fire the
// deferred session close so the worker's KV isn't released while
// generation teardown is still in progress.
let chooser = self.chooser.clone();
let context_id = self.context_id.clone();
handle.spawn(async move {
if needs_free && let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id} (drop guard): {e}");
}
if let Some(close) = deferred_close {
close.execute(&context_id);
}
});
}
}
......@@ -683,30 +634,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
let chooser = self.chooser.clone();
let dispatch_guard = PendingDispatchGuard::new(
chooser.clone(),
scheduler_tracked,
context_id.clone(),
deferred_close,
);
let mut response_stream = self
.inner
.direct(updated_request, instance_id)
.instrument(tracing::info_span!(
"kv_router.route_request",
request_id = %context_id,
worker_id = instance_id,
dp_rank = ?backend_dp_rank,
overlap_blocks = ?overlap_amount,
phase = ?phase,
))
.await?;
let deferred_close = dispatch_guard.disarm();
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
// Build the guard before returning the stream so a drop-before-first-poll
// still frees booked scheduler state.
let guard = RequestGuard {
// Build the guard BEFORE calling direct() so that its Drop covers the
// error path as well as the drop-before-first-poll path.
//
// Without this, if direct().await? below returns Err, both the
// scheduler slot (booked by find_best_match with update_states=true)
// and the SessionCloseAction (obtained above via on_routed) are leaked:
// SessionCloseAction has no Drop impl, so dropping it never sends the
// close_session RPC; chooser.free() is only called via RequestGuard::Drop.
//
// All guard fields are available here (deferred_close was just obtained;
// isl_tokens/block_size/tracker were set before request.into_parts()).
let mut guard = RequestGuard {
chooser: chooser.clone(),
scheduler_tracked,
context_id: context_id.clone(),
......@@ -723,9 +663,32 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
block_size,
expected_output_tokens,
deferred_close,
dispatched: false,
};
let mut response_stream = self
.inner
.direct(updated_request, instance_id)
.instrument(tracing::info_span!(
"kv_router.route_request",
request_id = %context_id,
worker_id = instance_id,
dp_rank = ?backend_dp_rank,
overlap_blocks = ?overlap_amount,
phase = ?phase,
))
.await?;
// direct() succeeded — mark dispatched so record_metrics() fires.
// If direct() returned Err above, guard drops here with dispatched=false
// → RequestGuard::Drop fires → chooser.free() + deferred_close.execute()
// but record_metrics() is suppressed (no backend work was done).
guard.dispatched = true;
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
let wrapped_stream = Box::pin(async_stream::stream! {
// Move guard into the stream closure. Drop fires here if the stream
// is polled to completion, or via the outer Drop if never polled.
let mut guard = guard;
loop {
......
......@@ -75,6 +75,7 @@ xxhash-rust = { workspace = true }
cudarc = { workspace = true, features = ["nvtx"], optional = true }
arc-swap = { version = "1" }
crossbeam-queue = { version = "0.3" }
async-once-cell = { version = "0.5.4" }
bincode = { version = "1" }
console-subscriber = { version = "0.4", optional = true }
......
......@@ -808,6 +808,7 @@ dependencies = [
"blake3",
"bytes",
"chrono",
"crossbeam-queue",
"dashmap",
"derive-getters",
"derive_builder",
......
......@@ -3,6 +3,12 @@
//! TCP Request Plane Client
//!
//! Lock-free, LRU-based shared connection pool with round-robin selection.
//! Connections are Arc-wrapped and shared across concurrent requests.
//! The hot path (per-request) is fully lock-free: ArcSwap + atomic round-robin + SegQueue push.
//! The cold path (connect/prune) uses a Mutex on the LRU cache.
//! Writer tasks batch requests into a reusable BytesMut buffer for a single write_all()
//! syscall per drain, avoiding the fixed-size cap and double-copy of BufWriter.
use super::unified_client::{ClientStats, Headers, RequestPlaneClient};
use crate::metrics::transport_metrics::{
......@@ -10,36 +16,61 @@ use crate::metrics::transport_metrics::{
};
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use crossbeam_queue::SegQueue;
use dashmap::DashMap;
use futures::StreamExt;
use std::io::IoSlice;
use lru::LruCache;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio_util::codec::FramedRead;
/// Default timeout for TCP request acknowledgment
const DEFAULT_TCP_REQUEST_TIMEOUT_SECS: u64 = 5;
/// Default connection pool size per host
/// Default connection pool size per host.
/// Ceiling: DEFAULT_POOL_SIZE(100) x REQUEST_CHANNEL_BUFFER(1024) = 102,400 concurrent
/// slots per host. At sub-10ms transport latency only a handful of connections are
/// ever active simultaneously; the remainder are opened on-demand by should_grow().
const DEFAULT_POOL_SIZE: usize = 100;
/// Buffer size for request channel per connection (backpressure control)
const REQUEST_CHANNEL_BUFFER: usize = 50;
/// Admission semaphore permits (and pipelining depth) per connection.
/// Raised from 256 to 1024 for high-throughput frontends (1M+ RPS across ~100 backends).
/// At 1ms round-trip a single connection already supports ~1,000 concurrent requests,
/// so deeper pipelining avoids unnecessary connection proliferation and lets the
/// the writer task drain larger batches per write_all (fewer syscalls at high rate).
/// Head-of-line blocking stays acceptable because the TCP transport layer targets
/// sub-ms latency; later requests rarely wait long behind earlier ones.
/// Per-host ceiling: DEFAULT_POOL_SIZE(100) x REQUEST_CHANNEL_BUFFER(1024) = 102,400.
const REQUEST_CHANNEL_BUFFER: usize = 1024;
/// Pre-allocated read buffer size (64KB typical message size)
const READ_BUFFER_SIZE: usize = 65536;
/// Maximum retries when another task is connecting (prevents unbounded recursion)
const MAX_CONNECT_RETRIES: usize = 5;
/// Default global connect concurrency limit across all hosts
const DEFAULT_GLOBAL_CONNECT_LIMIT: usize = 64;
/// Default idle host TTL in seconds before cleanup
const DEFAULT_HOST_IDLE_TTL_SECS: u64 = 300;
/// Default maximum message size for TCP client (32 MB)
/// This is the limit for a SINGLE message. For larger data, split into multiple messages.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;
/// Spin loop limit before falling back to async Notify in writer task
const WRITER_SPIN_LIMIT: u32 = 64;
/// Initial capacity of the per-writer BytesMut send buffer (256 KB).
/// The buffer grows automatically beyond this if a batch exceeds it, then
/// stays at the high-water mark for subsequent batches (amortised zero allocation).
const WRITER_INITIAL_BUF_CAPACITY: usize = 256 * 1024;
/// Get maximum message size from environment or use default
fn get_max_message_size() -> usize {
std::env::var("DYN_TCP_MAX_MESSAGE_SIZE")
......@@ -48,6 +79,13 @@ fn get_max_message_size() -> usize {
.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE)
}
/// Check if latency tracing is enabled via environment
fn latency_trace_enabled() -> bool {
std::env::var("DYN_TCP_LATENCY_TRACE")
.ok()
.is_some_and(|v| v == "1" || v == "true")
}
/// TCP request plane configuration
#[derive(Debug, Clone)]
pub struct TcpRequestConfig {
......@@ -105,40 +143,112 @@ impl TcpRequestConfig {
}
}
/// Request to be sent over TCP
/// Pre-encoded on caller's thread for optimal write performance (hot path optimization)
struct TcpRequest {
/// Pending request in the lock-free submit queue
struct PendingRequest {
/// Pre-encoded request data ready to send (zero-copy Bytes)
/// Encoding happens on caller thread to parallelize across multiple request handlers
encoded_data: Bytes,
/// Oneshot channel to send response back to caller
response_tx: oneshot::Sender<Result<Bytes>>,
}
/// TCP connection with split read/write tasks
/// RAII guard that decrements the inflight counter on drop.
///
/// Guarantees `fetch_sub` runs on ALL exit paths from `send_request`:
/// - Normal return (success or `?` propagation)
/// - `tokio::time::timeout` cancellation (future dropped mid-await)
/// - Any other future drop (e.g., `select!` branch cancellation)
///
/// Without this guard, a timeout drops the future at `response_rx.await`
/// and the `fetch_sub` below it is never reached, permanently inflating
/// the inflight counter and poisoning `available_capacity()`.
struct InflightGuard(Arc<AtomicU64>);
impl Drop for InflightGuard {
fn drop(&mut self) {
// Release: pairs with the Acquire loads in available_capacity() and
// cleanup_idle_hosts() so the decrement is visible to any reader that
// subsequently observes the updated counter value.
self.0.fetch_sub(1, Ordering::Release);
}
}
/// RAII guard that resets the `connecting` CAS gate on drop.
///
/// Design: One writer task + one reader task per connection
/// - Writer task receives pre-encoded requests and writes directly (hot path optimized)
/// - Reader task uses framed codec for robust protocol handling
/// - FIFO ordering ensures request/response correlation without explicit IDs
/// After winning the CAS (`connecting` set to `true`), two subsequent
/// await points in `ensure_capacity_or_heal` are cancellation-unsafe:
///
/// 1. `connect_limiter.acquire().await` — if the enclosing Tokio future is
/// dropped here, `connecting` stays `true` forever and the existing
/// `map_err` closure only runs when the `Semaphore` is *closed*, not on
/// cancellation.
/// 2. `TcpConnection::connect(...).await` — if cancelled here, the explicit
/// `self.connecting.store(false)` below the await is never reached.
///
/// In both cases callers that lost the CAS race will block on
/// `connect_notify.notified()` until their `connect_timeout` expires, then
/// retry the CAS — but `connecting` is still `true`, so they time out again,
/// permanently stalling pool growth for the affected host.
///
/// Fix: construct this guard immediately after the CAS succeeds. Its Drop
/// unconditionally resets `connecting` and wakes waiters, covering every exit
/// path: normal return (Ok/Err), `?` propagation, and future cancellation.
/// Double-reset (store false when already false) and double-notify (wake an
/// empty waiter set) are both no-ops, so explicit cleanup calls in the success
/// and error branches are left in place for readability without any risk.
struct ConnectingGuard<'a> {
connecting: &'a AtomicBool,
notify: &'a tokio::sync::Notify,
}
impl Drop for ConnectingGuard<'_> {
fn drop(&mut self) {
self.connecting.store(false, Ordering::Release);
self.notify.notify_waiters();
}
}
/// TCP connection with lock-free submit and batched write/read tasks.
///
/// Performance: Hybrid approach optimizes each path independently:
/// - Write path: Pre-encode on caller thread → direct write (minimal overhead, parallel encoding)
/// - Read path: Framed codec handles partial reads and protocol complexity automatically
/// Design: SegQueue submit → batched writer task → reader task → oneshot response
/// - Callers push to SegQueue (lock-free, ~20-40ns)
/// - Writer task drains queue into a reusable BytesMut, single write_all per batch
/// - Reader task uses framed codec, pops response_tx from SegQueue
/// - FIFO ordering: writer pushes ALL response_txs AFTER write_all succeeds
struct TcpConnection {
addr: SocketAddr,
/// Channel to send requests to the writer task
request_tx: mpsc::Sender<TcpRequest>,
/// Lock-free queue for callers to submit requests
submit_queue: Arc<SegQueue<PendingRequest>>,
/// Lock-free queue for writer→reader response_tx handoff
response_queue: Arc<SegQueue<oneshot::Sender<Result<Bytes>>>>,
/// Notify to wake writer when submit_queue transitions from empty
writer_notify: Arc<tokio::sync::Notify>,
/// Writer task handle for cleanup
writer_handle: Arc<JoinHandle<()>>,
/// Reader task handle for cleanup
reader_handle: Arc<JoinHandle<()>>,
/// Health status (false if tasks have failed)
healthy: Arc<AtomicBool>,
/// Closed for new submissions once the writer begins its terminal drain path.
/// This closes the race where a request can enqueue after the final drain pass
/// and then wait until the outer request timeout.
closed: Arc<AtomicBool>,
/// Number of in-flight requests (for capacity heuristic)
inflight: Arc<AtomicU64>,
/// Max inflight for capacity heuristic (matches channel_buffer)
channel_buffer: usize,
/// Bounded admission semaphore (permits == channel_buffer).
/// Callers must acquire a permit before pushing to submit_queue,
/// enforcing channel_buffer as a hard limit and preventing unbounded
/// SegQueue growth and heap OOM under sustained overload.
/// Permit is held for the lifetime of the call and released on drop
/// (success, error via `?`, or timeout/cancel future drop).
admission: Arc<tokio::sync::Semaphore>,
#[cfg(test)]
post_enqueue_barrier: Option<Arc<tokio::sync::Barrier>>,
}
impl TcpConnection {
/// Create a new connection with split read/write tasks
/// Create a new connection with lock-free submit and batched write/read tasks
async fn connect(addr: SocketAddr, timeout: Duration, channel_buffer: usize) -> Result<Self> {
let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
.await
......@@ -149,50 +259,83 @@ impl TcpConnection {
let (read_half, write_half) = tokio::io::split(stream);
// Channel for writer task to receive requests
let (request_tx, request_rx) = mpsc::channel::<TcpRequest>(channel_buffer);
// Channel for writer to forward response channels to reader (FIFO correlation)
let (response_tx_channel, response_rx_channel) =
mpsc::unbounded_channel::<oneshot::Sender<Result<Bytes>>>();
let submit_queue = Arc::new(SegQueue::new());
let response_queue = Arc::new(SegQueue::new());
let writer_notify = Arc::new(tokio::sync::Notify::new());
let healthy = Arc::new(AtomicBool::new(true));
let closed = Arc::new(AtomicBool::new(false));
let inflight = Arc::new(AtomicU64::new(0));
let admission = Arc::new(tokio::sync::Semaphore::new(channel_buffer));
// Spawn writer task
// Spawn writer task (batches into BytesMut, single write_all per drain)
let writer_handle = {
let submit_q = submit_queue.clone();
let response_q = response_queue.clone();
let notify = writer_notify.clone();
let healthy = healthy.clone();
let closed = closed.clone();
let admission = admission.clone();
tokio::spawn(async move {
if let Err(e) = Self::writer_task(write_half, request_rx, response_tx_channel).await
if let Err(e) = Self::writer_task(
write_half,
submit_q,
response_q,
notify,
healthy.clone(),
closed.clone(),
)
.await
{
tracing::debug!("Writer task failed for {}: {}", addr, e);
// healthy and closed are already set inside writer_task before
// drain_pending; these are idempotent backstops.
healthy.store(false, Ordering::Relaxed);
closed.store(true, Ordering::Release);
// Unblock any callers currently waiting to acquire a permit
// so they fail fast via the post-acquire health re-check.
admission.close();
}
})
};
// Spawn reader task
// Spawn reader task (passes writer_notify so reader can wake writer on exit)
let reader_handle = {
let response_q = response_queue.clone();
let healthy = healthy.clone();
let writer_notify = writer_notify.clone();
let admission = admission.clone();
tokio::spawn(async move {
if let Err(e) = Self::reader_task(read_half, response_rx_channel).await {
if let Err(e) =
Self::reader_task(read_half, response_q, healthy.clone(), writer_notify).await
{
tracing::debug!("Reader task failed for {}: {}", addr, e);
healthy.store(false, Ordering::Relaxed);
// Unblock any callers currently waiting to acquire a permit.
admission.close();
}
})
};
Ok(Self {
addr,
request_tx,
submit_queue,
response_queue,
writer_notify,
writer_handle: Arc::new(writer_handle),
reader_handle: Arc::new(reader_handle),
healthy,
closed,
inflight,
channel_buffer,
admission,
#[cfg(test)]
post_enqueue_barrier: None,
})
}
/// Configure socket for ultra-low latency based on dyn-transports patterns
fn configure_socket(stream: &TcpStream) -> Result<()> {
use socket2::{SockRef, Socket};
use socket2::SockRef;
let sock_ref = SockRef::from(stream);
......@@ -238,44 +381,215 @@ impl TcpConnection {
Ok(())
}
/// Writer task: receives pre-encoded requests and writes directly to socket
/// Drain the submit queue, sending errors on all oneshot senders.
/// Called when the writer task exits to prevent orphaned callers waiting
/// forever on requests that were never processed.
///
/// NOTE: response_queue is intentionally NOT drained here. Items in
/// response_queue are for requests already flushed to the wire -- the
/// reader may still deliver their responses. If the reader also exits,
/// the remaining oneshot senders are dropped when the TcpConnection is
/// cleaned up, and callers receive a RecvError caught by the outer timeout.
fn drain_pending(submit_queue: &SegQueue<PendingRequest>) {
while let Some(req) = submit_queue.pop() {
let _ = req
.response_tx
.send(Err(anyhow::anyhow!("Connection closed")));
}
}
/// Drain committed response waiters when the reader determines that no more
/// responses can ever arrive on this connection.
fn drain_response_waiters(
response_queue: &SegQueue<oneshot::Sender<Result<Bytes>>>,
err_msg: impl Into<String>,
) {
let err_msg = err_msg.into();
while let Some(tx) = response_queue.pop() {
let _ = tx.send(Err(anyhow::anyhow!("{}", err_msg)));
}
}
/// Writer task: drains SegQueue into a reusable BytesMut, then issues a single
/// write_all() per drain cycle — one syscall regardless of batch size.
///
/// Why BytesMut instead of BufWriter:
/// - BufWriter has a fixed internal cap (256 KB); batches larger than that trigger
/// implicit mid-batch partial flushes, breaking the one-syscall-per-batch guarantee.
/// With channel_buffer=1024 this happens routinely under moderate load.
/// - BufWriter copies each Bytes into its internal Vec<u8> and then the kernel
/// copies again on flush — two copies per request. BytesMut collapses this to
/// one extend_from_slice + one write_all (single kernel copy).
/// - BytesMut grows to the batch HWM and stays there; after warm-up there are
/// zero allocations per batch.
///
/// Performance optimization: Pre-encoding happens on caller's thread to enable
/// parallel encoding across multiple request handlers, while this task focuses
/// on sequential socket writes with minimal overhead.
/// Flush-boundary tracking: response_txs are held locally during the write
/// phase and only pushed to response_queue AFTER write_all succeeds. This way:
/// - On write error, callers in the current batch get immediate errors
/// (not "Connection closed" via drain_pending)
/// - Previously written batches stay in response_queue for the reader to
/// deliver -- they are NOT erroneously killed by drain_pending
/// - The server cannot respond before write_all returns, so the reader will
/// never see a response before its response_tx is in the queue
async fn writer_task(
mut write_half: tokio::io::WriteHalf<TcpStream>,
mut request_rx: mpsc::Receiver<TcpRequest>,
response_tx_channel: mpsc::UnboundedSender<oneshot::Sender<Result<Bytes>>>,
submit_queue: Arc<SegQueue<PendingRequest>>,
response_queue: Arc<SegQueue<oneshot::Sender<Result<Bytes>>>>,
notify: Arc<tokio::sync::Notify>,
healthy: Arc<AtomicBool>,
closed: Arc<AtomicBool>,
) -> Result<()> {
while let Some(req) = request_rx.recv().await {
// Direct write of pre-encoded data (hot path)
// With TCP_NODELAY, no need for explicit flush()
match write_half.write_all(&req.encoded_data).await {
Ok(()) => {
TCP_BYTES_SENT_TOTAL.inc_by(req.encoded_data.len() as f64);
// Forward response channel to reader task (FIFO ordering)
if response_tx_channel.send(req.response_tx).is_err() {
tracing::debug!("Reader task closed, stopping writer");
let mut send_buf = BytesMut::with_capacity(WRITER_INITIAL_BUF_CAPACITY);
// Hoisted outside the loop to reuse allocations across drain cycles.
let mut encoded_batch: Vec<Bytes> = Vec::with_capacity(64);
let mut response_batch: Vec<oneshot::Sender<Result<Bytes>>> = Vec::with_capacity(64);
let trace = latency_trace_enabled();
// Latency instrumentation accumulators
let mut batch_count: u64 = 0;
let mut total_batch_size: u64 = 0;
let mut total_batch_write_ns: u64 = 0;
let mut last_report = std::time::Instant::now();
let result: Result<()> = async {
loop {
// Adaptive spin: try queue before falling back to async Notify
let mut spins: u32 = 0;
while submit_queue.is_empty() {
// Check if reader has exited
if !healthy.load(Ordering::Relaxed) {
return Err(anyhow::anyhow!("Reader exited, writer stopping"));
}
spins += 1;
if spins >= WRITER_SPIN_LIMIT {
notify.notified().await;
break;
}
std::hint::spin_loop();
}
Err(e) => {
// Write failed, notify caller and stop
// Drain all available requests (reuse pre-allocated Vecs)
encoded_batch.clear();
response_batch.clear();
while let Some(req) = submit_queue.pop() {
encoded_batch.push(req.encoded_data);
response_batch.push(req.response_tx);
}
let count = encoded_batch.len();
if count == 0 {
continue; // spurious wakeup
}
// Phase 1: Gather all encoded payloads into the send buffer.
// A single extend_from_slice per item — no intermediate BufWriter
// copy, no implicit partial flushes if the batch exceeds a cap.
// response_txs stay local — they are NOT in response_queue yet.
let write_start = if trace {
Some(std::time::Instant::now())
} else {
None
};
for data in &encoded_batch {
send_buf.extend_from_slice(data);
}
// Phase 2: Single write_all = one syscall for the entire batch.
// The socket send buffer is 2 MB; batches that fit go out in one
// writev(). Larger batches loop inside write_all but still hit
// the kernel only as fast as it drains the socket buffer.
if let Err(e) = write_half.write_all(&send_buf).await {
// Data may be partially on the wire — the connection is in an
// unrecoverable state (broken framing). Fail the entire batch,
// clear the buffer defensively so stale data can never be
// re-sent if reconnect-and-retry is ever added, then exit.
send_buf.clear();
let err_msg = format!("Write failed: {}", e);
let _ = req.response_tx.send(Err(anyhow::anyhow!("{}", err_msg)));
return Err(anyhow::anyhow!("{}", err_msg));
for tx in response_batch.drain(..) {
let _ = tx.send(Err(anyhow::anyhow!("{}", err_msg)));
}
return Err(e.into());
}
TCP_BYTES_SENT_TOTAL.inc_by(send_buf.len() as f64);
send_buf.clear(); // reset length, keep allocation for next batch
// Phase 3: write_all succeeded — data is committed to the wire.
// NOW push response_txs to response_queue so the reader can
// match them with incoming responses.
for tx in response_batch.drain(..) {
response_queue.push(tx);
}
// Check if reader has exited (e.g., peer closed connection).
// Kernel buffering may let writes succeed after peer close,
// but the reader won't deliver responses — exit now.
if !healthy.load(Ordering::Relaxed) {
return Err(anyhow::anyhow!("Reader exited, writer stopping"));
}
// Latency instrumentation
if trace {
if let Some(start) = write_start {
total_batch_write_ns += start.elapsed().as_nanos() as u64;
}
batch_count += 1;
total_batch_size += count as u64;
if last_report.elapsed() >= Duration::from_secs(5) {
let avg_batch = if batch_count > 0 {
total_batch_size / batch_count
} else {
0
};
let avg_write_ns = if batch_count > 0 {
total_batch_write_ns / batch_count
} else {
0
};
tracing::info!(
batches = batch_count,
avg_batch_size = avg_batch,
avg_batch_write_ns = avg_write_ns,
"TCP writer instrumentation summary"
);
batch_count = 0;
total_batch_size = 0;
total_batch_write_ns = 0;
last_report = std::time::Instant::now();
}
}
encoded_batch.clear();
}
}
Ok(())
.await;
// On exit, drain only the submit_queue (unprocessed requests).
// response_queue items are for committed (flushed) data — the reader
// may still deliver their responses. If it can't, the oneshot senders
// drop when TcpConnection is cleaned up and callers get RecvError.
healthy.store(false, Ordering::Relaxed);
// Signal closed BEFORE drain so any concurrent sender that enqueues
// between this drain and the end of the function will see closed=true
// in its post-enqueue double-check and call drain_pending itself.
// Without this, the window between drain_pending and the error handler's
// closed.store(true) in the spawned wrapper leaves late enqueues stranded.
closed.store(true, Ordering::Release);
Self::drain_pending(&submit_queue);
result
}
/// Reader task: reads responses using framed codec and sends them back via oneshot channels
/// Protocol framing handled automatically via TcpResponseCodec
/// Reader task: reads responses using framed codec, pops response_tx from SegQueue.
///
/// On exit (clean close or error), sets `healthy=false` and wakes the writer
/// via `writer_notify` so it can detect reader death and drain pending callers.
async fn reader_task(
read_half: tokio::io::ReadHalf<TcpStream>,
mut response_rx_channel: mpsc::UnboundedReceiver<oneshot::Sender<Result<Bytes>>>,
response_queue: Arc<SegQueue<oneshot::Sender<Result<Bytes>>>>,
healthy: Arc<AtomicBool>,
writer_notify: Arc<tokio::sync::Notify>,
) -> Result<()> {
use crate::pipeline::network::codec::TcpResponseCodec;
......@@ -283,147 +597,689 @@ impl TcpConnection {
let codec = TcpResponseCodec::new(Some(max_message_size));
let mut framed = FramedRead::new(read_half, codec);
while let Some(response_tx) = response_rx_channel.recv().await {
// Read the next response message from the framed stream
// The codec handles all protocol framing and size checks automatically
match framed.next().await {
Some(Ok(response_msg)) => {
let _ = response_tx.send(Ok(response_msg.data));
while let Some(result) = framed.next().await {
// Spin briefly if response_queue is empty (writer hasn't pushed yet)
let tx = loop {
if let Some(tx) = response_queue.pop() {
break tx;
}
Some(Err(e)) => {
let err = anyhow::anyhow!("Failed to decode response: {}", e);
let _ = response_tx.send(Err(err));
return Err(anyhow::anyhow!("Failed to decode response"));
// If the connection is already unhealthy (writer failed), stop spinning
if !healthy.load(Ordering::Relaxed) {
return Err(anyhow::anyhow!("Connection unhealthy, reader exiting"));
}
tokio::task::yield_now().await;
};
match result {
Ok(response_msg) => {
let _ = tx.send(Ok(response_msg.data));
}
None => {
let err = anyhow::anyhow!("Connection closed by peer");
let _ = response_tx.send(Err(err));
return Err(anyhow::anyhow!("Connection closed"));
Err(e) => {
let err_msg = format!("Failed to decode response: {}", e);
let _ = tx.send(Err(anyhow::anyhow!("{}", err_msg)));
healthy.store(false, Ordering::Relaxed);
Self::drain_response_waiters(
&response_queue,
format!("Connection closed after decode failure: {}", e),
);
// Wake writer so it can detect unhealthy and exit
writer_notify.notify_one();
return Err(anyhow::anyhow!("Failed to decode response"));
}
}
}
// Connection closed by peer — mark unhealthy so the writer task
// detects reader death and drains pending callers.
healthy.store(false, Ordering::Relaxed);
Self::drain_response_waiters(
&response_queue,
"Connection closed before response was received",
);
// Wake writer from Notify.await so it can check healthy and exit
writer_notify.notify_one();
Ok(())
}
/// Send a request and wait for response
///
/// Performance: Encoding happens on caller's thread (hot path optimization)
/// to enable parallel encoding across multiple request handlers. The writer
/// task then performs sequential writes with minimal overhead.
/// Send a request via lock-free SegQueue push (~20-40ns)
async fn send_request(&self, payload: Bytes, headers: &Headers) -> Result<Bytes> {
use crate::pipeline::network::codec::TcpRequestMessage;
// Check health before sending
if !self.healthy.load(Ordering::Relaxed) {
anyhow::bail!("Connection unhealthy (tasks failed)");
}
if self.closed.load(Ordering::Acquire) {
anyhow::bail!("Connection closed (writer exited)");
}
// Extract endpoint path from headers (required for routing)
let endpoint_path = headers
.get("x-endpoint-path")
.ok_or_else(|| anyhow::anyhow!("Missing x-endpoint-path header for TCP request"))?
.to_string();
// Encode request on caller's thread (hot path optimization)
// This allows multiple concurrent callers to encode in parallel
// rather than serializing through the writer task
// Include all headers (especially trace headers) in the message
let trace = latency_trace_enabled();
let e2e_start = if trace {
Some(std::time::Instant::now())
} else {
None
};
// Bounded admission: block until a slot is free (channel_buffer hard limit).
// The permit is held for the duration of this call and released on drop,
// whether the caller returns normally, errors out, or the enclosing
// tokio::time::timeout drops this future mid-flight.
// This prevents unbounded SegQueue growth and heap OOM under overload.
// encode() runs AFTER acquire so callers blocked on the semaphore do not
// hold a pre-allocated encoded frame, bounding peak memory to
// channel_buffer * frame_size per connection.
let _permit = self
.admission
.acquire()
.await
.map_err(|_| anyhow::anyhow!("Connection closed (admission gate shut)"))?;
// encode() called here — after admission is granted — so the frame is
// only allocated once we have capacity to process it.
let request_msg = TcpRequestMessage::with_headers(endpoint_path, headers.clone(), payload);
let encoded_data = request_msg.encode()?;
// Create response channel
let (response_tx, response_rx) = oneshot::channel();
// Send to writer task (bounded channel provides backpressure)
let req = TcpRequest {
// Re-check health: the connection may have gone unhealthy while we
// were waiting for a permit. Fail fast so the caller can retry on a
// fresh connection rather than pushing to a dead submit_queue.
if !self.healthy.load(Ordering::Relaxed) {
anyhow::bail!("Connection unhealthy (tasks failed)");
}
if self.closed.load(Ordering::Acquire) {
anyhow::bail!("Connection closed (writer exited)");
}
// Increment inflight and attach a RAII guard that decrements it on ALL
// exit paths: normal return, `?` propagation, tokio::time::timeout
// cancellation, or any other future drop. Without the guard a timeout
// drops the future at response_rx.await and fetch_sub is never reached,
// permanently inflating the counter and poisoning available_capacity().
// Release: symmetric with the Release in InflightGuard::drop so that
// Acquire loads in available_capacity() see a consistent value.
self.inflight.fetch_add(1, Ordering::Release);
let _inflight_guard = InflightGuard(self.inflight.clone());
// Lock-free submit: ~20-40ns
self.submit_queue.push(PendingRequest {
encoded_data,
response_tx,
};
});
self.request_tx
.send(req)
.await
.map_err(|_| anyhow::anyhow!("Writer task closed"))?;
#[cfg(test)]
if let Some(barrier) = &self.post_enqueue_barrier {
barrier.wait().await;
barrier.wait().await;
}
// Wait for response from reader task
response_rx
// If the writer has already entered terminal cleanup, fail the just-enqueued request
// promptly instead of waiting for the outer request timeout.
if self.closed.load(Ordering::Acquire) {
Self::drain_pending(&self.submit_queue);
} else {
// Wake writer if it was sleeping
self.writer_notify.notify_one();
// The writer may close between the first check and notify. Drain once more so
// late enqueues do not get stranded past the final writer drain pass.
if self.closed.load(Ordering::Acquire) {
Self::drain_pending(&self.submit_queue);
}
}
// Await response. On timeout the enclosing future is dropped here:
// `_inflight_guard` drops → fetch_sub(Release) runs automatically.
// `_permit` drops → semaphore slot is released automatically.
let result = response_rx
.await
.map_err(|_| anyhow::anyhow!("Reader task closed"))?
.map_err(|_| anyhow::anyhow!("Reader task closed"))?;
if trace && let Some(start) = e2e_start {
let e2e_ns = start.elapsed().as_nanos() as u64;
tracing::trace!(e2e_ns = e2e_ns, "TCP request e2e latency");
}
result
}
/// Check if connection is healthy
fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::Relaxed)
}
/// Available capacity (advisory, for cold path growth heuristic)
fn available_capacity(&self) -> usize {
// Acquire: pairs with Release in fetch_add/InflightGuard::drop so we
// observe the most recently committed inflight value from any thread.
let inflight = self.inflight.load(Ordering::Acquire) as usize;
self.channel_buffer.saturating_sub(inflight)
}
}
/// Connection pool with health checking for TCP connections
struct TcpConnectionPool {
pools: DashMap<SocketAddr, Arc<Mutex<Vec<TcpConnection>>>>,
config: TcpRequestConfig,
/// Per-host connection pool with LRU lifecycle and ArcSwap-based snapshot.
///
/// Hot path: `ArcSwap::load()` + atomic round-robin (~40ns total, fully lock-free).
/// Cold path: LRU prune/insert + ArcSwap store (only on startup or failure).
struct HostPool {
/// Lock-free snapshot for the hot path (rebuilt on cold path changes)
snapshot: arc_swap::ArcSwap<Vec<Arc<TcpConnection>>>,
/// LRU cache for lifecycle management (eviction, pruning)
lru: parking_lot::Mutex<LruCache<u64, Arc<TcpConnection>>>,
/// Atomic round-robin counter for connection selection
counter: AtomicU64,
/// Monotonic ID generator for LRU keys
next_id: AtomicU64,
/// CAS gate to prevent thundering-herd connect storms
connecting: AtomicBool,
/// Wake CAS losers when a connect attempt completes (success or failure)
connect_notify: tokio::sync::Notify,
/// Maximum connections for this host
max_connections: usize,
/// Target address
addr: SocketAddr,
/// Connect timeout
connect_timeout: Duration,
/// Channel buffer size for new connections
channel_buffer: usize,
/// Timestamp of last use (unix millis) for idle cleanup
last_used_ms: AtomicU64,
}
impl TcpConnectionPool {
fn new(config: TcpRequestConfig) -> Self {
impl HostPool {
fn new(addr: SocketAddr, config: &TcpRequestConfig) -> Self {
let cap = NonZeroUsize::new(config.pool_size).unwrap_or(NonZeroUsize::new(1).unwrap());
Self {
pools: DashMap::new(),
config,
snapshot: arc_swap::ArcSwap::from_pointee(Vec::new()),
lru: parking_lot::Mutex::new(LruCache::new(cap)),
counter: AtomicU64::new(0),
next_id: AtomicU64::new(0),
connecting: AtomicBool::new(false),
connect_notify: tokio::sync::Notify::new(),
max_connections: config.pool_size,
addr,
connect_timeout: config.connect_timeout,
channel_buffer: config.channel_buffer,
last_used_ms: AtomicU64::new(current_time_ms()),
}
}
/// Get a connection from the pool or create a new one
/// Automatically filters out unhealthy connections
async fn get_connection(&self, addr: SocketAddr) -> Result<TcpConnection> {
// Try to get from pool (lock-free read with DashMap)
if let Some(pool) = self.pools.get(&addr) {
let mut pool = pool.lock().await;
/// Get a connection, using the hot path (ArcSwap load + atomic RR) when possible.
async fn get_connection(
&self,
connect_limiter: &tokio::sync::Semaphore,
) -> Result<Arc<TcpConnection>> {
// === HOT PATH: ArcSwap load + atomic round-robin (fully lock-free) ===
{
let guard = self.snapshot.load();
let conns = &**guard;
let len = conns.len();
if len > 0 {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
// First pass: find a healthy connection with available capacity
for i in 0..len {
let idx = (start + i) % len;
let conn = &conns[idx];
if conn.is_healthy() && conn.available_capacity() > 0 {
return Ok(conn.clone());
}
}
// Try to get a healthy connection, discard unhealthy ones
while let Some(conn) = pool.pop() {
if conn.is_healthy() {
return Ok(conn);
} else {
tracing::debug!("Discarding unhealthy connection for {addr}");
// Connection will be dropped here, cleaning up tasks
// All healthy connections are saturated.
// If pool is at max, return a saturated healthy connection (SegQueue backpressure handles it).
// If pool can still grow, fall through to cold path to add capacity.
if len >= self.max_connections {
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() {
return Ok(conns[idx].clone());
}
}
}
// Fall through: all unhealthy OR all saturated and below max_connections
}
}
// Create new connection with configured channel buffer
tracing::debug!("Creating new TCP connection to {addr}");
TcpConnection::connect(
addr,
self.config.connect_timeout,
self.config.channel_buffer,
// === COLD PATH ===
self.ensure_capacity_or_heal(connect_limiter).await
}
/// Determine if the pool should grow
fn should_grow(healthy: &[Arc<TcpConnection>], max_connections: usize) -> bool {
if healthy.is_empty() {
return true;
}
if healthy.len() >= max_connections {
return false;
}
// Grow when every healthy connection's channel is fully saturated
healthy.iter().all(|c| c.available_capacity() == 0)
}
/// Cold path: prune unhealthy connections, optionally grow, rebuild snapshot.
async fn ensure_capacity_or_heal(
&self,
connect_limiter: &tokio::sync::Semaphore,
) -> Result<Arc<TcpConnection>> {
// --- Phase A: lock LRU, prune, decide, build snapshot, unlock ---
let (need_connect, new_snap) = {
let mut lru = self.lru.lock();
// Prune unhealthy (evicted Arcs stay alive for in-flight holders)
let dead: Vec<u64> = lru
.iter()
.filter(|(_, c)| !c.is_healthy())
.map(|(&k, _)| k)
.collect();
for k in dead {
lru.pop(&k);
}
let snap: Vec<Arc<TcpConnection>> = lru.iter().map(|(_, c)| c.clone()).collect();
let grow = Self::should_grow(&snap, self.max_connections);
(grow, snap)
};
// LRU lock released here
// Atomic snapshot update (no RwLock!)
self.snapshot.store(Arc::new(new_snap.clone()));
// Re-check the snapshot for a usable connection. When need_connect
// is true, require available_capacity() > 0 to avoid returning a
// saturated connection and defeating pool growth. This is consistent
// with the hot-path check and the retry check in Phase B.
{
let guard = self.snapshot.load();
let conns = &**guard;
if !conns.is_empty() {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..conns.len() {
let idx = (start + i) % conns.len();
let conn = &conns[idx];
if conn.is_healthy() && (!need_connect || conn.available_capacity() > 0) {
return Ok(conn.clone());
}
}
}
}
if !need_connect {
anyhow::bail!(
"No healthy TCP connection to {} and pool at capacity ({})",
self.addr,
self.max_connections
);
}
// --- Phase B: connect (no locks held, CAS gate prevents stampede) ---
// Bounded retry loop instead of recursion
for retry in 0..MAX_CONNECT_RETRIES {
if self
.connecting
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
// Won the CAS gate. Guard resets it (and wakes waiters) on ALL
// exit paths — normal return, ?, and future cancellation.
let _connecting_guard = ConnectingGuard {
connecting: &self.connecting,
notify: &self.connect_notify,
};
// Acquire global connect permit to limit total connect bursts.
// If this await is cancelled, _connecting_guard drops and
// resets the gate — no manual cleanup needed.
let _permit = connect_limiter
.acquire()
.await
.map_err(|_| anyhow::anyhow!("Global connect limiter closed"))?;
let connect_result =
TcpConnection::connect(self.addr, self.connect_timeout, self.channel_buffer)
.await;
self.connecting.store(false, Ordering::Release);
match connect_result {
Ok(stream) => {
let new_conn = Arc::new(stream);
// --- Phase C: lock LRU, insert, rebuild snapshot, unlock ---
{
let mut lru = self.lru.lock();
// Re-prune (may have changed during connect)
let dead: Vec<u64> = lru
.iter()
.filter(|(_, c)| !c.is_healthy())
.map(|(&k, _)| k)
.collect();
for k in dead {
lru.pop(&k);
}
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
lru.put(id, new_conn.clone());
let snap: Vec<Arc<TcpConnection>> =
lru.iter().map(|(_, c)| c.clone()).collect();
drop(lru);
self.snapshot.store(Arc::new(snap));
}
self.connect_notify.notify_waiters();
return Ok(new_conn);
}
Err(e) => {
self.connect_notify.notify_waiters();
return Err(e);
}
}
}
// Another task is connecting. Wait for it to finish (or timeout).
let _ =
tokio::time::timeout(self.connect_timeout, self.connect_notify.notified()).await;
// Try hot path again after yield
let guard = self.snapshot.load();
let conns = &**guard;
let len = conns.len();
if len > 0 {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() && conns[idx].available_capacity() > 0 {
return Ok(conns[idx].clone());
}
}
// Accept saturated if at max
if len >= self.max_connections {
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() {
return Ok(conns[idx].clone());
}
}
}
}
drop(guard);
tracing::trace!(
"TCP pool connect retry {}/{} for {}",
retry + 1,
MAX_CONNECT_RETRIES,
self.addr
);
}
// All connect attempts failed. Fall back to any healthy connection
// (even if saturated) rather than dropping the request — SegQueue
// backpressure handles overload gracefully.
{
let guard = self.snapshot.load();
let conns = &**guard;
let len = conns.len();
if len > 0 {
let start = self.counter.fetch_add(1, Ordering::Relaxed) as usize;
for i in 0..len {
let idx = (start + i) % len;
if conns[idx].is_healthy() {
return Ok(conns[idx].clone());
}
}
}
}
anyhow::bail!(
"No healthy TCP connection to {} after {} connect retries",
self.addr,
MAX_CONNECT_RETRIES
)
.await
}
}
/// Return a connection to the pool if it's healthy and there's space
async fn return_connection(&self, conn: TcpConnection) {
// Only return healthy connections
if !conn.is_healthy() {
tracing::debug!("Not returning unhealthy connection to pool");
return;
/// Get current time in milliseconds since epoch
fn current_time_ms() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
/// Connection pool with LRU lifecycle and shared Arc connections.
///
/// Uses DashMap for per-host pools and a global Semaphore to limit
/// aggregate connect bursts across many cold hosts.
struct TcpConnectionPool {
hosts: DashMap<SocketAddr, Arc<HostPool>>,
config: TcpRequestConfig,
/// Global connect concurrency limiter (caps aggregate connect bursts)
connect_limiter: Arc<tokio::sync::Semaphore>,
/// Idle host TTL for cleanup
host_idle_ttl_ms: u64,
}
impl TcpConnectionPool {
fn new(config: TcpRequestConfig) -> Self {
let global_limit = std::env::var("DYN_TCP_GLOBAL_CONNECT_LIMIT")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(DEFAULT_GLOBAL_CONNECT_LIMIT);
let host_idle_ttl_secs = std::env::var("DYN_TCP_HOST_IDLE_TTL_SECS")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(DEFAULT_HOST_IDLE_TTL_SECS);
Self {
hosts: DashMap::new(),
config,
connect_limiter: Arc::new(tokio::sync::Semaphore::new(global_limit)),
host_idle_ttl_ms: host_idle_ttl_secs * 1000,
}
}
let addr = conn.addr;
/// Get a connection from the pool or create a new one.
/// Hot path: DashMap shard read lock → ArcSwap load → atomic RR.
async fn get_connection(&self, addr: SocketAddr) -> Result<Arc<TcpConnection>> {
// Fast: clone the Arc out of the DashMap guard so the shard lock is
// released before any `.await` (DashMap guards are !Send and holding
// them across await points blocks other shard operations).
if let Some(host) = self.hosts.get(&addr).map(|entry| Arc::clone(&*entry)) {
host.last_used_ms
.store(current_time_ms(), Ordering::Relaxed);
return host.get_connection(&self.connect_limiter).await;
}
// Get or create pool for this address (lock-free with DashMap)
let pool = self
.pools
// Slow: first request to this host
let host = self
.hosts
.entry(addr)
.or_insert_with(|| Arc::new(Mutex::new(Vec::new())))
.or_insert_with(|| Arc::new(HostPool::new(addr, &self.config)))
.clone();
let mut pool = pool.lock().await;
if pool.len() < self.config.pool_size {
pool.push(conn);
} else {
tracing::debug!("Connection pool full for {addr}, dropping connection");
// Otherwise drop the connection (tasks will be cleaned up)
host.last_used_ms
.store(current_time_ms(), Ordering::Relaxed);
host.get_connection(&self.connect_limiter).await
}
/// Eagerly establish one TCP connection to the given address.
///
/// Creates a `HostPool` entry (if absent) and opens a single connection
/// through the normal cold path so the global `connect_limiter` is respected.
/// Failures are logged but never propagated — the lazy cold path remains
/// as fallback.
async fn warmup(&self, addr: SocketAddr) {
let host = self
.hosts
.entry(addr)
.or_insert_with(|| Arc::new(HostPool::new(addr, &self.config)))
.clone();
host.last_used_ms
.store(current_time_ms(), Ordering::Relaxed);
match host.get_connection(&self.connect_limiter).await {
Ok(_) => tracing::debug!("TCP warmup: pre-connected to {}", addr),
Err(e) => tracing::warn!("TCP warmup: failed to pre-connect to {}: {}", addr, e),
}
}
/// Background task that watches the instance discovery channel and
/// eagerly warms one TCP connection for each newly-discovered TCP backend.
///
/// Uses a diff-based approach: tracks a `HashSet<SocketAddr>` of known
/// addresses, only warms truly new ones.
fn start_warmup_watcher(
self: &Arc<Self>,
mut instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
cancel_token: tokio_util::sync::CancellationToken,
) {
let pool = Arc::clone(self);
tokio::spawn(async move {
let mut known_addrs = std::collections::HashSet::<SocketAddr>::new();
// Seed from current value so we don't re-warm existing backends
{
let instances = instance_rx.borrow_and_update();
for inst in instances.iter() {
if let crate::component::TransportType::Tcp(ref addr_str) = inst.transport
&& let Ok((sock, _)) = TcpRequestClient::parse_address(addr_str)
{
known_addrs.insert(sock);
}
}
}
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::debug!("TCP warmup watcher cancelled");
break;
}
result = instance_rx.changed() => {
if result.is_err() {
tracing::debug!("TCP warmup watcher: instance channel closed");
break;
}
let instances = instance_rx.borrow_and_update().clone();
let mut current_addrs = std::collections::HashSet::<SocketAddr>::new();
for inst in &instances {
if let crate::component::TransportType::Tcp(ref addr_str) = inst.transport
&& let Ok((sock, _)) = TcpRequestClient::parse_address(addr_str)
{
current_addrs.insert(sock);
if !known_addrs.contains(&sock) {
let pool = Arc::clone(&pool);
tokio::spawn(async move {
pool.warmup(sock).await;
});
}
}
}
known_addrs = current_addrs;
}
}
}
});
}
/// Opportunistic cleanup of idle host pools.
/// Called periodically by the background maintenance task; not on the hot path.
fn cleanup_idle_hosts(&self) {
let now = current_time_ms();
let ttl = self.host_idle_ttl_ms;
let stale: Vec<SocketAddr> = self
.hosts
.iter()
.filter(|entry| {
let last = entry.value().last_used_ms.load(Ordering::Relaxed);
if now.saturating_sub(last) <= ttl {
return false;
}
// Don't evict hosts that still have in-flight requests --
// a long-running request with no new checkouts would look
// idle but killing it would fail legitimate work.
let snap = entry.value().snapshot.load();
!snap.iter().any(|c| c.inflight.load(Ordering::Acquire) > 0)
})
.map(|entry| *entry.key())
.collect();
for addr in stale {
tracing::debug!("Removing idle TCP host pool for {}", addr);
if let Some((_, host)) = self.hosts.remove(&addr) {
// Shut down connections: mark unhealthy, wake writers so they
// exit and drain pending callers, and abort readers since they
// may be blocked in framed.next().await on a quiet peer and
// would otherwise hold the socket FD indefinitely.
let snap = host.snapshot.load();
for conn in snap.iter() {
conn.healthy.store(false, Ordering::Relaxed);
// Unblock callers waiting on the admission semaphore so
// they fail fast via the post-acquire health re-check,
// rather than waiting for drain_pending to free permits.
conn.admission.close();
conn.writer_notify.notify_one();
conn.reader_handle.abort();
}
}
}
}
/// Count active and idle connections across all host pools.
/// Active = healthy with inflight > 0, idle = healthy with inflight == 0.
fn connection_counts(&self) -> (usize, usize) {
let mut active = 0usize;
let mut idle = 0usize;
for entry in self.hosts.iter() {
let snap = entry.value().snapshot.load();
for conn in snap.iter() {
if conn.is_healthy() {
if conn.inflight.load(Ordering::Acquire) > 0 {
active += 1;
} else {
idle += 1;
}
}
}
}
(active, idle)
}
/// Spawn a background task that periodically cleans up idle host pools.
///
/// Uses a `Weak` reference so the task automatically stops when the
/// `TcpConnectionPool` is dropped (no explicit cancellation needed).
/// The cleanup interval is half the idle TTL so hosts are reaped
/// reasonably promptly after expiry.
fn start_idle_cleanup(self: &Arc<Self>) {
// Only spawn when a tokio runtime is available. In production this is
// always the case (clients are created from async contexts), but sync
// unit tests construct TcpRequestClient without a runtime.
let Ok(handle) = tokio::runtime::Handle::try_current() else {
tracing::debug!("No tokio runtime available; idle host cleanup disabled");
return;
};
let pool = Arc::downgrade(self);
// Floor at 30s to prevent busy-loop if DYN_TCP_HOST_IDLE_TTL_SECS=0
let interval = Duration::from_millis((self.host_idle_ttl_ms / 2).max(30_000));
handle.spawn(async move {
loop {
tokio::time::sleep(interval).await;
match pool.upgrade() {
Some(pool) => pool.cleanup_idle_hosts(),
None => break,
}
}
});
}
}
......@@ -450,8 +1306,10 @@ impl TcpRequestClient {
/// Create a new TCP request client with custom configuration
pub fn with_config(config: TcpRequestConfig) -> Result<Self> {
let pool = Arc::new(TcpConnectionPool::new(config.clone()));
pool.start_idle_cleanup();
Ok(Self {
pool: Arc::new(TcpConnectionPool::new(config.clone())),
pool,
config,
stats: Arc::new(TcpClientStats {
requests_sent: AtomicU64::new(0),
......@@ -468,10 +1326,22 @@ impl TcpRequestClient {
Self::with_config(TcpRequestConfig::from_env())
}
/// Start a background task that eagerly warms TCP connections for
/// newly-discovered backends.
///
/// Delegates to [`TcpConnectionPool::start_warmup_watcher`].
pub fn start_warmup(
&self,
instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
cancel_token: tokio_util::sync::CancellationToken,
) {
self.pool.start_warmup_watcher(instance_rx, cancel_token);
}
/// Parse TCP address from string
/// Supports formats: "host:port" or "tcp://host:port" or "host:port/endpoint_name"
/// Returns (SocketAddr, Option<endpoint_name>)
fn parse_address(address: &str) -> Result<(SocketAddr, Option<String>)> {
pub(crate) fn parse_address(address: &str) -> Result<(SocketAddr, Option<String>)> {
let addr_str = if let Some(stripped) = address.strip_prefix("tcp://") {
stripped
} else {
......@@ -507,25 +1377,21 @@ impl RequestPlaneClient for TcpRequestClient {
payload: Bytes,
mut headers: Headers,
) -> Result<Bytes> {
tracing::debug!("TCP client sending request to address: {address}");
tracing::debug!("TCP client sending request to address: {}", address);
self.stats.requests_sent.fetch_add(1, Ordering::Relaxed);
let payload_len = payload.len();
self.stats
.bytes_sent
.fetch_add(payload_len as u64, Ordering::Relaxed);
.fetch_add(payload.len() as u64, Ordering::Relaxed);
let (addr, endpoint_name) = Self::parse_address(&address)?;
// Add endpoint path to headers if present in address
if let Some(endpoint_name) = endpoint_name {
headers.insert("x-endpoint-path".to_string(), endpoint_name.clone());
}
// Get connection from pool (automatically filters unhealthy connections)
// Get shared connection from pool (Arc, not exclusive borrow)
let conn = self.pool.get_connection(addr).await?;
// Send request with timeout
// Note: The connection's send_request now handles all the async I/O via tasks
let result = tokio::time::timeout(
self.config.request_timeout,
conn.send_request(payload, &headers),
......@@ -541,17 +1407,13 @@ impl RequestPlaneClient for TcpRequestClient {
.bytes_received
.fetch_add(response.len() as u64, Ordering::Relaxed);
TCP_BYTES_RECEIVED_TOTAL.inc_by(response.len() as f64);
// Return connection to pool (health check happens inside)
self.pool.return_connection(conn).await;
// conn (Arc) dropped here -- connection stays in pool
Ok(response)
}
Ok(Err(e)) => {
self.stats.errors.fetch_add(1, Ordering::Relaxed);
TCP_ERRORS_TOTAL.inc();
tracing::warn!("TCP request failed to {}: {}", addr, e);
// Don't return unhealthy connection to pool, let it drop
let cause = crate::error::DynamoError::from(
e.into_boxed_dyn_error() as Box<dyn std::error::Error + 'static>
);
......@@ -566,8 +1428,7 @@ impl RequestPlaneClient for TcpRequestClient {
Err(_) => {
self.stats.errors.fetch_add(1, Ordering::Relaxed);
TCP_ERRORS_TOTAL.inc();
tracing::warn!("TCP request timeout to {addr}");
// Don't return timed-out connection to pool
tracing::warn!("TCP request timeout to {}", addr);
Err(anyhow::anyhow!(
crate::error::DynamoError::builder()
.error_type(crate::error::ErrorType::CannotConnect)
......@@ -586,15 +1447,24 @@ impl RequestPlaneClient for TcpRequestClient {
true // TCP client is always healthy if it was created successfully
}
fn start_warmup(
&self,
instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
cancel_token: tokio_util::sync::CancellationToken,
) {
TcpRequestClient::start_warmup(self, instance_rx, cancel_token);
}
fn stats(&self) -> ClientStats {
let (active, idle) = self.pool.connection_counts();
ClientStats {
requests_sent: self.stats.requests_sent.load(Ordering::Relaxed),
responses_received: self.stats.responses_received.load(Ordering::Relaxed),
errors: self.stats.errors.load(Ordering::Relaxed),
bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
active_connections: 0, // Could track this if needed
idle_connections: 0,
active_connections: active,
idle_connections: idle,
avg_latency_us: 0,
}
}
......@@ -604,6 +1474,7 @@ impl RequestPlaneClient for TcpRequestClient {
mod tests {
use super::*;
use std::sync::atomic::AtomicUsize;
use tokio::io::AsyncReadExt;
use tokio::net::TcpListener;
#[test]
......@@ -667,94 +1538,141 @@ mod tests {
assert!(client.is_healthy());
}
#[tokio::test]
async fn test_connection_health_check() {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage};
// Start a mock TCP server
/// Helper: spawn a mock TCP server that echoes requests.
/// Returns (listener_addr, connection_count_tracker).
async fn spawn_echo_server() -> (SocketAddr, Arc<AtomicUsize>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let conn_count = Arc::new(AtomicUsize::new(0));
let conn_count_clone = conn_count.clone();
// Spawn server that responds to requests
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
// Read path length and path
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await.unwrap();
let path_len = u16::from_be_bytes(len_buf) as usize;
loop {
let result = listener.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
conn_count_clone.fetch_add(1, Ordering::SeqCst);
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await.unwrap();
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
// Read path length
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
if read_half.read_exact(&mut path_buf).await.is_err() {
break;
}
// Read headers length and headers
let mut headers_len_buf = [0u8; 2];
read_half.read_exact(&mut headers_len_buf).await.unwrap();
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
// Read headers length
let mut headers_len_buf = [0u8; 2];
if read_half.read_exact(&mut headers_len_buf).await.is_err() {
break;
}
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
if read_half.read_exact(&mut headers_buf).await.is_err() {
break;
}
// Read payload length + payload
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
if read_half.read_exact(&mut payload_buf).await.is_err() {
break;
}
// Send response
use crate::pipeline::network::codec::TcpResponseMessage;
let response = TcpResponseMessage::new(Bytes::from(payload_buf));
let encoded = response.encode().unwrap();
if write_half.write_all(&encoded).await.is_err() {
break;
}
}
});
}
});
(addr, conn_count)
}
#[tokio::test]
async fn test_connection_health_check() {
use crate::pipeline::network::codec::TcpResponseMessage;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await.unwrap();
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await.unwrap();
let mut headers_len_buf = [0u8; 2];
read_half.read_exact(&mut headers_len_buf).await.unwrap();
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
read_half.read_exact(&mut headers_buf).await.unwrap();
// Read payload length and payload
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await.unwrap();
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
read_half.read_exact(&mut payload_buf).await.unwrap();
// Send response
let response = TcpResponseMessage::new(Bytes::from_static(b"pong"));
let encoded = response.encode().unwrap();
write_half.write_all(&encoded).await.unwrap();
});
// Create connection
let conn = TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap();
assert!(conn.is_healthy(), "New connection should be healthy");
// Send a request
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let result = conn.send_request(Bytes::from("ping"), &headers).await;
assert!(result.is_ok(), "Request should succeed");
assert_eq!(result.unwrap(), Bytes::from("pong"));
assert!(
conn.is_healthy(),
"Connection should remain healthy after successful request"
);
}
#[tokio::test]
async fn test_concurrent_requests_single_connection() {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage};
use crate::pipeline::network::codec::TcpResponseMessage;
// Start a mock TCP server that handles multiple requests
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_clone = request_count.clone();
// Spawn server that responds to multiple requests
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
// Handle 5 requests
for _ in 0..5 {
// Read request
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
if read_half.read_exact(&mut path_buf).await.is_err() {
break;
......@@ -765,7 +1683,6 @@ mod tests {
break;
}
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
if read_half.read_exact(&mut headers_buf).await.is_err() {
break;
......@@ -776,7 +1693,6 @@ mod tests {
break;
}
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
if read_half.read_exact(&mut payload_buf).await.is_err() {
break;
......@@ -784,7 +1700,6 @@ mod tests {
request_count_clone.fetch_add(1, Ordering::SeqCst);
// Send response
let response = TcpResponseMessage::new(Bytes::from(payload_buf));
let encoded = response.encode().unwrap();
if write_half.write_all(&encoded).await.is_err() {
......@@ -793,14 +1708,12 @@ mod tests {
}
});
// Create connection
let conn = Arc::new(
TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap(),
);
// Send 5 concurrent requests
let mut handles = vec![];
for i in 0..5 {
let conn = conn.clone();
......@@ -815,7 +1728,6 @@ mod tests {
handles.push(handle);
}
// Wait for all requests to complete
let mut results = vec![];
for handle in handles {
let result = handle.await.unwrap();
......@@ -823,10 +1735,7 @@ mod tests {
results.push(result.unwrap());
}
// Verify all requests got responses
assert_eq!(results.len(), 5);
// Verify server received all requests
assert_eq!(
request_count.load(Ordering::SeqCst),
5,
......@@ -835,66 +1744,279 @@ mod tests {
}
#[tokio::test]
async fn test_connection_pool_reuse() {
use crate::pipeline::network::codec::TcpResponseMessage;
async fn test_lru_connection_reuse() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
pool_size: 4,
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config);
// Get connection twice sequentially -- should reuse the same TCP connection
let conn1 = pool.get_connection(addr).await.unwrap();
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let _ = conn1
.send_request(Bytes::from("ping1"), &headers)
.await
.unwrap();
drop(conn1); // Arc ref dropped, but conn stays in pool
let conn2 = pool.get_connection(addr).await.unwrap();
let _ = conn2
.send_request(Bytes::from("ping2"), &headers)
.await
.unwrap();
drop(conn2);
assert_eq!(
conn_count.load(Ordering::SeqCst),
1,
"Should reuse connection from pool (1 TCP connection total)"
);
}
#[tokio::test]
async fn test_lru_eviction_keeps_inflight_alive() {
let (addr, _conn_count) = spawn_echo_server().await;
// Pool size 1 so we can evict easily
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
pool_size: 1,
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config);
// Get a connection and hold the Arc
let conn = pool.get_connection(addr).await.unwrap();
// Start a mock TCP server
// Mark it unhealthy to force the pool to create a new one
conn.healthy.store(false, Ordering::Relaxed);
// Getting another connection should create a new one (old one evicted from LRU)
let conn2 = pool.get_connection(addr).await.unwrap();
assert!(conn2.is_healthy());
// Original conn Arc is still alive (not dropped) -- it just can't be used
assert!(!conn.is_healthy());
// The Arc keeps the resources alive even though LRU evicted it
assert!(Arc::strong_count(&conn.writer_handle) >= 1);
}
#[tokio::test]
async fn test_high_concurrency_bounded_connections() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let mut handles = vec![];
for i in 0..500 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
let total_conns = conn_count.load(Ordering::SeqCst);
assert!(
total_conns <= 2,
"Should create at most pool_size (2) connections, got {}",
total_conns
);
assert!(ok_count > 0, "At least some requests should succeed");
}
#[tokio::test]
async fn test_thundering_herd_cold_start() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 4,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// 100 tasks all racing on a cold pool
let mut handles = vec![];
for i in 0..100 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
for handle in handles {
let _ = handle.await.unwrap();
}
let total_conns = conn_count.load(Ordering::SeqCst);
assert!(
total_conns <= 4,
"Thundering herd: should create at most pool_size (4) connections, got {}",
total_conns
);
}
#[tokio::test]
async fn test_server_crash_recovery() {
// Start a server we can kill
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let connection_count = Arc::new(AtomicUsize::new(0));
let connection_count_clone = connection_count.clone();
let cancel = tokio_util::sync::CancellationToken::new();
let cancel_clone = cancel.clone();
// Spawn server that accepts multiple connections
let server_task = tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
if let Ok((stream, _)) = result {
let cancel = cancel_clone.clone();
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
tokio::select! {
_ = cancel.cancelled() => break,
result = async {
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await?;
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut buf = vec![0u8; path_len];
read_half.read_exact(&mut buf).await?;
let mut hlen = [0u8; 2];
read_half.read_exact(&mut hlen).await?;
let hl = u16::from_be_bytes(hlen) as usize;
let mut hbuf = vec![0u8; hl];
read_half.read_exact(&mut hbuf).await?;
let mut plen = [0u8; 4];
read_half.read_exact(&mut plen).await?;
let pl = u32::from_be_bytes(plen) as usize;
let mut pbuf = vec![0u8; pl];
read_half.read_exact(&mut pbuf).await?;
use crate::pipeline::network::codec::TcpResponseMessage;
let resp = TcpResponseMessage::new(Bytes::from(pbuf));
let enc = resp.encode()?;
write_half.write_all(&enc).await?;
Ok::<_, anyhow::Error>(())
} => {
if result.is_err() { break; }
}
}
}
});
}
}
_ = cancel_clone.cancelled() => break,
}
}
});
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(2),
connect_timeout: Duration::from_secs(2),
pool_size: 2,
channel_buffer: 10,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// Use pool successfully
let conn = pool.get_connection(addr).await.unwrap();
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let result = conn
.send_request(Bytes::from("before_crash"), &headers)
.await;
assert!(result.is_ok());
drop(conn);
// Kill the server
cancel.cancel();
let _ = server_task.await;
// Requests should fail (server is gone)
tokio::time::sleep(Duration::from_millis(50)).await;
let conn = pool.get_connection(addr).await;
// Pool should either fail to connect or return an unhealthy conn
if let Ok(conn) = conn {
let result = conn
.send_request(Bytes::from("after_crash"), &headers)
.await;
// Either send fails or connection is unhealthy
assert!(result.is_err() || !conn.is_healthy());
}
// Start a new server on the same port
let listener2 = TcpListener::bind(addr).await.unwrap();
tokio::spawn(async move {
loop {
let result = listener.accept().await;
let result = listener2.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
connection_count_clone.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
// Read request
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut path_buf = vec![0u8; path_len];
if read_half.read_exact(&mut path_buf).await.is_err() {
let mut buf = vec![0u8; path_len];
if read_half.read_exact(&mut buf).await.is_err() {
break;
}
let mut headers_len_buf = [0u8; 2];
if read_half.read_exact(&mut headers_len_buf).await.is_err() {
let mut hlen = [0u8; 2];
if read_half.read_exact(&mut hlen).await.is_err() {
break;
}
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
if read_half.read_exact(&mut headers_buf).await.is_err() {
let hl = u16::from_be_bytes(hlen) as usize;
let mut hbuf = vec![0u8; hl];
if read_half.read_exact(&mut hbuf).await.is_err() {
break;
}
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
let mut plen = [0u8; 4];
if read_half.read_exact(&mut plen).await.is_err() {
break;
}
let payload_len = u32::from_be_bytes(len_buf) as usize;
let mut payload_buf = vec![0u8; payload_len];
if read_half.read_exact(&mut payload_buf).await.is_err() {
let pl = u32::from_be_bytes(plen) as usize;
let mut pbuf = vec![0u8; pl];
if read_half.read_exact(&mut pbuf).await.is_err() {
break;
}
// Send response
let response = TcpResponseMessage::new(Bytes::from_static(b"ok"));
let encoded = response.encode().unwrap();
if write_half.write_all(&encoded).await.is_err() {
use crate::pipeline::network::codec::TcpResponseMessage;
let resp = TcpResponseMessage::new(Bytes::from(pbuf));
let enc = resp.encode().unwrap();
if write_half.write_all(&enc).await.is_err() {
break;
}
}
......@@ -902,28 +2024,350 @@ mod tests {
}
});
// Pool should heal: get new connection and succeed
tokio::time::sleep(Duration::from_millis(100)).await;
let conn = pool.get_connection(addr).await.unwrap();
let result = conn
.send_request(Bytes::from("after_recovery"), &headers)
.await;
assert!(result.is_ok(), "Pool should heal after server recovery");
}
#[tokio::test]
async fn test_pool_scales_under_pressure() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 4,
channel_buffer: 1, // Very small buffer to force saturation quickly
};
let pool = Arc::new(TcpConnectionPool::new(config));
// Send enough concurrent requests to saturate channel_buffer=1
let mut handles = vec![];
for i in 0..20 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
for handle in handles {
let _ = handle.await.unwrap();
}
let total_conns = conn_count.load(Ordering::SeqCst);
assert!(
total_conns > 1,
"Pool should scale beyond 1 connection under pressure, got {}",
total_conns
);
assert!(
total_conns <= 4,
"Pool should not exceed pool_size (4), got {}",
total_conns
);
}
#[tokio::test]
async fn test_pool_size_cap_sustained_load() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 3,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// 3 rounds of 200 requests
for round in 0..3 {
let mut handles = vec![];
for i in 0..200 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("round_{}_req_{}", round, i)), &headers)
.await
}));
}
for handle in handles {
let _ = handle.await.unwrap();
}
}
let total_conns = conn_count.load(Ordering::SeqCst);
assert!(
total_conns <= 3,
"Sustained load should not exceed pool_size (3), got {}",
total_conns
);
}
#[tokio::test]
async fn test_backpressure_small_channel() {
let (addr, _conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 1,
channel_buffer: 1,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// 50 requests through pool_size=1 buffer=1 -- all should complete via backpressure
let mut handles = vec![];
for i in 0..50 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
assert_eq!(
ok_count, 50,
"All 50 requests should complete under backpressure"
);
}
#[tokio::test]
async fn test_no_recursive_retry_under_connect_contention() {
// This test verifies that connect contention uses bounded retries, not recursion.
// We test this by having many tasks race on a cold pool with pool_size=1.
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 1,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// Many tasks racing, only one should connect
let mut handles = vec![];
for _ in 0..50 {
let pool = pool.clone();
handles.push(tokio::spawn(async move { pool.get_connection(addr).await }));
}
let mut ok = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok += 1;
}
}
// All should succeed (one connects, others retry via hot path)
assert!(ok > 0, "At least some tasks should get connections");
assert_eq!(
conn_count.load(Ordering::SeqCst),
1,
"Only 1 TCP connection should be created"
);
}
#[tokio::test]
async fn test_global_connect_limiter_multi_host() {
// Spawn servers on 4 different ports to simulate multiple hosts
let mut addrs = vec![];
let total_conns = Arc::new(AtomicUsize::new(0));
for _ in 0..4 {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
addrs.push(addr);
let total_conns = total_conns.clone();
tokio::spawn(async move {
loop {
let result = listener.accept().await;
if result.is_err() {
break;
}
let (stream, _) = result.unwrap();
total_conns.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let (mut read_half, mut write_half) = tokio::io::split(stream);
loop {
let mut len_buf = [0u8; 2];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
}
let path_len = u16::from_be_bytes(len_buf) as usize;
let mut buf = vec![0u8; path_len];
if read_half.read_exact(&mut buf).await.is_err() {
break;
}
let mut hlen = [0u8; 2];
if read_half.read_exact(&mut hlen).await.is_err() {
break;
}
let hl = u16::from_be_bytes(hlen) as usize;
let mut hbuf = vec![0u8; hl];
if read_half.read_exact(&mut hbuf).await.is_err() {
break;
}
let mut plen = [0u8; 4];
if read_half.read_exact(&mut plen).await.is_err() {
break;
}
let pl = u32::from_be_bytes(plen) as usize;
let mut pbuf = vec![0u8; pl];
if read_half.read_exact(&mut pbuf).await.is_err() {
break;
}
use crate::pipeline::network::codec::TcpResponseMessage;
let resp = TcpResponseMessage::new(Bytes::from(pbuf));
let enc = resp.encode().unwrap();
if write_half.write_all(&enc).await.is_err() {
break;
}
}
});
}
});
}
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 50,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// Hit all 4 hosts concurrently
let mut handles = vec![];
for addr in &addrs {
let pool = pool.clone();
let addr = *addr;
for i in 0..10 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
let conn = pool.get_connection(addr).await?;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("req_{}", i)), &headers)
.await
}));
}
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
assert!(
ok_count > 0,
"Requests across multiple hosts should succeed"
);
// Total connections across all hosts should be bounded
let tc = total_conns.load(Ordering::SeqCst);
assert!(
tc <= 8,
"Total connections across 4 hosts should be <= 4*pool_size(2)=8, got {}",
tc
);
}
#[tokio::test]
async fn test_idle_host_pool_cleanup() {
let (addr, _conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config);
// Override TTL to 0 for testing
let pool = TcpConnectionPool {
host_idle_ttl_ms: 0,
..pool
};
// Create a connection to populate the host entry
let conn = pool.get_connection(addr).await.unwrap();
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let _ = conn.send_request(Bytes::from("test"), &headers).await;
drop(conn);
assert!(pool.hosts.contains_key(&addr), "Host entry should exist");
// Wait a tiny bit so the timestamp is stale
tokio::time::sleep(Duration::from_millis(10)).await;
pool.cleanup_idle_hosts();
assert!(
!pool.hosts.contains_key(&addr),
"Idle host entry should be cleaned up"
);
}
#[tokio::test]
async fn test_connection_pool_reuse() {
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
pool_size: 2,
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config);
// Get connection twice from pool
let conn1 = pool.get_connection(addr).await.unwrap();
pool.return_connection(conn1).await;
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let _ = conn1
.send_request(Bytes::from("test1"), &headers)
.await
.unwrap();
drop(conn1);
// Small delay to ensure connection is returned
tokio::time::sleep(Duration::from_millis(10)).await;
let conn2 = pool.get_connection(addr).await.unwrap();
pool.return_connection(conn2).await;
let _ = conn2
.send_request(Bytes::from("test2"), &headers)
.await
.unwrap();
drop(conn2);
// Should have created only 1 TCP connection since we reused
assert_eq!(
connection_count.load(Ordering::SeqCst),
conn_count.load(Ordering::SeqCst),
1,
"Should reuse connection from pool"
);
......@@ -937,7 +2381,7 @@ mod tests {
// Server that immediately closes connections
tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
drop(stream); // Immediately close
drop(stream);
}
});
......@@ -948,25 +2392,178 @@ mod tests {
channel_buffer: 10,
};
let pool = TcpConnectionPool::new(config.clone());
// Try to get a connection - it will become unhealthy quickly
let result =
TcpConnection::connect(addr, config.connect_timeout, config.channel_buffer).await;
if let Ok(conn) = result {
// Mark as unhealthy by trying to use it
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
let _ = conn.send_request(Bytes::from("test"), &headers).await;
let result = tokio::time::timeout(
Duration::from_millis(250),
conn.send_request(Bytes::from("test"), &headers),
)
.await;
assert!(
result.is_ok(),
"send_request should fail promptly when the peer closes cleanly"
);
assert!(
result.unwrap().is_err(),
"clean peer close should surface as a request error"
);
// Connection should become unhealthy after server drops it
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(
!conn.is_healthy(),
"Connection should become unhealthy after peer closes it"
);
}
}
// Return to pool
pool.return_connection(conn).await;
#[tokio::test]
async fn test_warmup_pre_connects_on_instance_discovery() {
use crate::component::{Instance, TransportType};
// Try to get from pool again - should get a new connection attempt
let result2 = pool.get_connection(addr).await;
// This might fail or succeed depending on timing, but should not panic
let _ = result2;
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(5),
connect_timeout: Duration::from_secs(5),
pool_size: 4,
channel_buffer: 10,
};
let pool = Arc::new(TcpConnectionPool::new(config));
let cancel_token = tokio_util::sync::CancellationToken::new();
// Create a watch channel with no instances initially
let (instance_tx, instance_rx) = tokio::sync::watch::channel::<Vec<Instance>>(Vec::new());
// Start the warmup watcher
pool.start_warmup_watcher(instance_rx, cancel_token.clone());
// Let the watcher task start and begin polling changed()
tokio::time::sleep(Duration::from_millis(50)).await;
// No connections yet
assert_eq!(conn_count.load(Ordering::SeqCst), 0);
// Discover a new TCP backend
let tcp_addr = format!("{}:{}/test_endpoint", addr.ip(), addr.port());
instance_tx
.send(vec![Instance {
component: "test".to_string(),
endpoint: "test_ep".to_string(),
namespace: "default".to_string(),
instance_id: 1,
transport: TransportType::Tcp(tcp_addr),
device_type: None,
}])
.unwrap();
// Give the warmup watcher time to process and connect
tokio::time::sleep(Duration::from_millis(200)).await;
// Should have pre-connected
assert_eq!(
conn_count.load(Ordering::SeqCst),
1,
"Warmup should have created 1 connection to the newly discovered backend"
);
// Pool should have the host entry
assert!(
pool.hosts.contains_key(&addr),
"Pool should have a host entry for the warmed address"
);
cancel_token.cancel();
}
#[tokio::test]
async fn test_closed_connection_after_enqueue_fails_promptly() {
let (addr, _conn_count) = spawn_echo_server().await;
let mut conn = TcpConnection::connect(addr, Duration::from_secs(5), 10)
.await
.unwrap();
let barrier = Arc::new(tokio::sync::Barrier::new(2));
conn.post_enqueue_barrier = Some(barrier.clone());
let conn = Arc::new(conn);
let request = {
let conn = conn.clone();
tokio::spawn(async move {
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
tokio::time::timeout(
Duration::from_millis(200),
conn.send_request(Bytes::from("queued_after_close"), &headers),
)
.await
})
};
// Wait until the request is enqueued, then simulate the writer entering
// its terminal cleanup path before it can process the new item.
barrier.wait().await;
conn.closed.store(true, Ordering::Release);
conn.healthy.store(false, Ordering::Relaxed);
barrier.wait().await;
let result = request.await.unwrap();
assert!(
result.is_ok(),
"send_request should fail promptly instead of waiting for request_timeout"
);
let inner = result.unwrap();
assert!(inner.is_err(), "closed connection should return an error");
conn.writer_handle.abort();
conn.reader_handle.abort();
}
#[tokio::test]
async fn test_lockfree_submit_and_batch() {
// Verify that concurrent submits to the same connection produce
// batched writes (batch_size > 1) by checking that all requests
// complete correctly even when submitted simultaneously.
let (addr, conn_count) = spawn_echo_server().await;
let config = TcpRequestConfig {
request_timeout: Duration::from_secs(10),
connect_timeout: Duration::from_secs(5),
pool_size: 1,
channel_buffer: 200,
};
let pool = Arc::new(TcpConnectionPool::new(config));
// Force a single connection, then blast concurrent requests
let conn = pool.get_connection(addr).await.unwrap();
let mut handles = vec![];
for i in 0..100 {
let conn = conn.clone();
handles.push(tokio::spawn(async move {
let mut headers = Headers::new();
headers.insert("x-endpoint-path".to_string(), "test".to_string());
conn.send_request(Bytes::from(format!("batch_req_{}", i)), &headers)
.await
}));
}
let mut ok_count = 0;
for handle in handles {
if handle.await.unwrap().is_ok() {
ok_count += 1;
}
}
assert_eq!(ok_count, 100, "All 100 concurrent requests should succeed");
assert_eq!(
conn_count.load(Ordering::SeqCst),
1,
"Should use only 1 connection"
);
}
}
......@@ -109,6 +109,16 @@ pub trait RequestPlaneClient: Send + Sync {
ClientStats::default()
}
/// Start a background task that eagerly warms connections for newly-discovered backends.
/// Only TCP overrides this; HTTP and NATS clients inherit the no-op.
fn start_warmup(
&self,
_instance_rx: tokio::sync::watch::Receiver<Vec<crate::component::Instance>>,
_cancel_token: tokio_util::sync::CancellationToken,
) {
// No-op default
}
/// Close the client gracefully (optional)
///
/// Implementations should:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment