Unverified Commit bbb79afb authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix: use zero copy decoder for handling high concurrency / request bursts for tcp ingress (#5376)

parent 1da603a4
...@@ -15,8 +15,10 @@ use tokio_util::{ ...@@ -15,8 +15,10 @@ use tokio_util::{
}; };
mod two_part; mod two_part;
pub mod zero_copy_decoder;
pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType}; pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
pub use zero_copy_decoder::{TcpRequestMessageZeroCopy, ZeroCopyTcpDecoder};
/// TCP request plane protocol message with endpoint routing and trace headers /// TCP request plane protocol message with endpoint routing and trace headers
/// ///
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Zero-copy TCP message decoder for high-concurrency scenarios
//!
//! This decoder eliminates message reconstruction copies by:
//! 1. Reading into a reusable buffer
//! 2. Parsing headers in-place
//! 3. Splitting off exact message sizes (zero-copy via BytesMut::split_to)
//! 4. Returning Arc-counted Bytes that can be cloned cheaply
use bytes::{Buf, Bytes, BytesMut};
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt};
/// Fallback cap on the size of one message (32MB), used when the environment
/// does not supply a usable override.
const MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024; // 32MB
/// Capacity the decoder's read buffer starts with when built via `new()`.
const INITIAL_BUFFER_SIZE: usize = 256 * 1024; // 256KB

/// Resolves the maximum accepted message size in bytes.
///
/// Reads `DYN_TCP_MAX_MESSAGE_SIZE`; if the variable is unset, not valid
/// Unicode, or does not parse as a `usize`, `MAX_MESSAGE_SIZE` is returned.
fn get_max_message_size() -> usize {
    match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(MAX_MESSAGE_SIZE),
        Err(_) => MAX_MESSAGE_SIZE,
    }
}
/// Zero-copy streaming decoder that reuses buffers
///
/// The decoder owns one read buffer that persists across `read_message` calls,
/// so bytes that arrive after the end of one message are kept for the next.
/// Each decoded message is split off the front of that buffer and returned as
/// frozen `Bytes`.
pub struct ZeroCopyTcpDecoder {
/// Bytes read from the reader accumulate here; `read_message` splits every
/// completed message off the front and leaves any trailing bytes in place.
// NOTE(review): the previous comment said this buffer "never shrinks"; since
// `split_to` hands the front of the buffer to the returned message, confirm
// against the `bytes::BytesMut` docs what `capacity()` reports afterwards.
read_buffer: BytesMut,
/// Maximum allowed total message size in bytes (length prefixes included),
/// resolved once at construction time via `get_max_message_size()`.
max_message_size: usize,
}
impl ZeroCopyTcpDecoder {
    /// Create a new decoder with the default initial buffer capacity.
    pub fn new() -> Self {
        Self::with_capacity(INITIAL_BUFFER_SIZE)
    }

    /// Create a new decoder whose read buffer starts with `capacity` bytes.
    ///
    /// The maximum accepted message size is resolved here, once, through
    /// `get_max_message_size()`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            read_buffer: BytesMut::with_capacity(capacity),
            max_message_size: get_max_message_size(),
        }
    }

    /// Reads from `reader` until at least `needed` bytes are buffered.
    ///
    /// Returns `Ok(true)` once enough bytes are available and `Ok(false)` if
    /// the reader reported EOF first. I/O errors are propagated unchanged.
    async fn fill_to<R: AsyncRead + Unpin>(
        &mut self,
        reader: &mut R,
        needed: usize,
    ) -> io::Result<bool> {
        while self.read_buffer.len() < needed {
            if reader.read_buf(&mut self.read_buffer).await? == 0 {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Error returned when the stream ends inside the length-prefixed header.
    fn incomplete_header() -> io::Error {
        io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete message header")
    }

    /// Read one complete message without copying it out of the read buffer.
    ///
    /// Wire format:
    /// `[path_len(2)][path][headers_len(2)][headers][payload_len(4)][payload]`,
    /// all lengths big-endian.
    ///
    /// The method buffers just enough bytes to parse each length prefix in
    /// place, validates the sizes, buffers the rest of the message, and then
    /// splits exactly that many bytes off the front of the read buffer.
    /// Bytes belonging to later messages stay buffered for the next call.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` - the reader hit EOF. The message is
    ///   `"connection closed"` when nothing was buffered, and an
    ///   "incomplete message" variant when a message was cut short.
    /// * `InvalidData` - the path length is `0` or above 1024, or the total
    ///   message size exceeds the configured maximum.
    /// * Any error returned by the underlying reader.
    pub async fn read_message<R: AsyncRead + Unpin>(
        &mut self,
        reader: &mut R,
    ) -> io::Result<TcpRequestMessageZeroCopy> {
        const PATH_LEN_SIZE: usize = 2;
        const HEADERS_LEN_SIZE: usize = 2;
        const PAYLOAD_LEN_SIZE: usize = 4;
        const MAX_PATH_LEN: usize = 1024;

        // Stage 1: the 2-byte endpoint path length.
        if !self.fill_to(reader, PATH_LEN_SIZE).await? {
            // EOF with an empty buffer is a clean close between messages;
            // anything else means the peer stopped mid-header.
            let reason = if self.read_buffer.is_empty() {
                "connection closed"
            } else {
                "incomplete message header"
            };
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, reason));
        }
        let path_len = u16::from_be_bytes([self.read_buffer[0], self.read_buffer[1]]) as usize;
        if path_len == 0 || path_len > MAX_PATH_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid endpoint path length: {}", path_len),
            ));
        }

        // Stage 2: the path itself plus the 2-byte headers length.
        let headers_len_offset = PATH_LEN_SIZE + path_len;
        if !self
            .fill_to(reader, headers_len_offset + HEADERS_LEN_SIZE)
            .await?
        {
            return Err(Self::incomplete_header());
        }
        let headers_len = u16::from_be_bytes([
            self.read_buffer[headers_len_offset],
            self.read_buffer[headers_len_offset + 1],
        ]) as usize;

        // Stage 3: the headers plus the 4-byte payload length.
        let payload_len_offset = headers_len_offset + HEADERS_LEN_SIZE + headers_len;
        let full_header_size = payload_len_offset + PAYLOAD_LEN_SIZE;
        if !self.fill_to(reader, full_header_size).await? {
            return Err(Self::incomplete_header());
        }
        let payload_len = u32::from_be_bytes([
            self.read_buffer[payload_len_offset],
            self.read_buffer[payload_len_offset + 1],
            self.read_buffer[payload_len_offset + 2],
            self.read_buffer[payload_len_offset + 3],
        ]) as usize;

        // `payload_len` is peer-controlled and may be as large as `u32::MAX`.
        // Saturate instead of wrapping so that, on targets where `usize` is
        // 32 bits, an oversized length cannot wrap around to a small total and
        // slip past the size check below.
        let total_len = full_header_size.saturating_add(payload_len);
        if total_len > self.max_message_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "message too large: {} bytes (max: {} bytes)",
                    total_len, self.max_message_size
                ),
            ));
        }

        // Stage 4: the remainder of the message.
        if !self.fill_to(reader, total_len).await? {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!(
                    "incomplete message: expected {} bytes, got {}",
                    total_len,
                    self.read_buffer.len()
                ),
            ));
        }

        // Hand exactly one message to the caller; whatever follows it in the
        // buffer is the start of the next message.
        let message_bytes = self.read_buffer.split_to(total_len).freeze();
        Ok(TcpRequestMessageZeroCopy::new(message_bytes))
    }

    /// Get the current capacity of the read buffer, in bytes.
    pub fn buffer_capacity(&self) -> usize {
        self.read_buffer.capacity()
    }

    /// Get the number of bytes currently buffered but not yet returned.
    pub fn buffered_len(&self) -> usize {
        self.read_buffer.len()
    }
}
impl Default for ZeroCopyTcpDecoder {
fn default() -> Self {
Self::new()
}
}
/// Zero-copy message representation
///
/// Holds one complete wire-format message in a single `Bytes` buffer. The
/// accessors decode the length prefixes on demand and return slices of, or
/// references into, that buffer instead of copying fields out.
#[derive(Clone)]
pub struct TcpRequestMessageZeroCopy {
/// The entire message, length prefixes included.
/// Format: [path_len(2)][path(var)][headers_len(2)][headers(var)][payload_len(4)][payload(var)]
/// Length prefixes are big-endian. The accessors index into this buffer
/// directly, so it must hold a complete, well-formed message.
raw: Bytes,
}
impl TcpRequestMessageZeroCopy {
    /// Wraps a buffer that holds one complete wire-format message.
    fn new(raw: Bytes) -> Self {
        Self { raw }
    }

    /// Decodes the big-endian `u16` length prefix stored at `offset`.
    #[inline]
    fn u16_len_at(&self, offset: usize) -> usize {
        u16::from_be_bytes([self.raw[offset], self.raw[offset + 1]]) as usize
    }

    /// Length of the endpoint path in bytes.
    #[inline]
    fn path_len(&self) -> usize {
        self.u16_len_at(0)
    }

    /// Offset of the first header byte (just past the `headers_len` prefix).
    #[inline]
    fn headers_start(&self) -> usize {
        2 + self.path_len() + 2
    }

    /// Endpoint path as a string slice borrowed from the message buffer.
    ///
    /// Fails if the path bytes are not valid UTF-8.
    pub fn endpoint_path(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(self.endpoint_path_bytes())
    }

    /// Endpoint path as raw bytes borrowed from the message buffer.
    pub fn endpoint_path_bytes(&self) -> &[u8] {
        let end = 2 + self.path_len();
        &self.raw[2..end]
    }

    /// Length of the headers section in bytes.
    #[inline]
    fn headers_len(&self) -> usize {
        self.u16_len_at(2 + self.path_len())
    }

    /// Headers section as raw bytes borrowed from the message buffer.
    pub fn headers_bytes(&self) -> &[u8] {
        let start = self.headers_start();
        &self.raw[start..start + self.headers_len()]
    }

    /// Headers parsed from their JSON encoding into an owned map.
    ///
    /// An empty headers section yields an empty map. A section that fails to
    /// parse as a JSON string-to-string map also yields an empty map; the
    /// parse error is not reported.
    pub fn headers(&self) -> std::collections::HashMap<String, String> {
        match self.headers_bytes() {
            [] => std::collections::HashMap::new(),
            encoded => serde_json::from_slice(encoded).unwrap_or_default(),
        }
    }

    /// Payload length as declared by the 4-byte big-endian prefix.
    #[inline]
    fn payload_len(&self) -> usize {
        let offset = self.headers_start() + self.headers_len();
        let mut prefix = [0u8; 4];
        prefix.copy_from_slice(&self.raw[offset..offset + 4]);
        u32::from_be_bytes(prefix) as usize
    }

    /// Payload as a `Bytes` slice of the message buffer.
    ///
    /// Covers everything after the `payload_len` prefix up to the end of the
    /// buffer; the payload bytes are not copied.
    pub fn payload(&self) -> Bytes {
        let payload_start = self.headers_start() + self.headers_len() + 4;
        self.raw.slice(payload_start..)
    }

    /// Total message size in bytes, length prefixes included.
    pub fn total_size(&self) -> usize {
        self.raw.len()
    }

    /// The raw message bytes (for debugging).
    pub fn raw_bytes(&self) -> &Bytes {
        &self.raw
    }
}
/// Debug output summarises the message (sizes and endpoint path) instead of
/// dumping the raw buffer. A non-UTF-8 endpoint path is shown as `None`.
impl std::fmt::Debug for TcpRequestMessageZeroCopy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("TcpRequestMessageZeroCopy");
        summary.field("total_size", &self.total_size());
        summary.field("endpoint_path", &self.endpoint_path().ok());
        summary.field("payload_len", &self.payload_len());
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes one request in the wire format the decoder expects:
    /// `[path_len: u16][path][headers_len: u16][headers][payload_len: u32][payload]`,
    /// all lengths big-endian.
    fn encode_message(endpoint: &str, headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let mut message =
            Vec::with_capacity(2 + endpoint.len() + 2 + headers.len() + 4 + payload.len());
        message.extend_from_slice(&(endpoint.len() as u16).to_be_bytes());
        message.extend_from_slice(endpoint.as_bytes());
        message.extend_from_slice(&(headers.len() as u16).to_be_bytes());
        message.extend_from_slice(headers);
        message.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        message.extend_from_slice(payload);
        message
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_basic() {
        let endpoint = "test/endpoint";
        let payload = b"Hello, World!";
        let message = encode_message(endpoint, &[], payload);

        // A byte slice serves as the mock reader.
        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.total_size(), message.len());
        assert_eq!(msg.headers().len(), 0); // Empty headers
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_large_payload() {
        // Large payload (200KB)
        let endpoint = "large/endpoint";
        let payload = vec![0x42u8; 200 * 1024];
        let message = encode_message(endpoint, &[], &payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().len(), payload.len());
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_total_size_limit() {
        // The limit applies to the total message size, not just the payload:
        // a payload of exactly `max_size` bytes must be rejected because the
        // length prefixes and path push the total over the limit.
        let max_size = 1024; // 1KB limit
        let mut decoder = ZeroCopyTcpDecoder::with_capacity(256);
        decoder.max_message_size = max_size;

        let endpoint = "test/endpoint";
        let payload = vec![0x42u8; max_size]; // Payload equals max
        let message = encode_message(endpoint, &[], &payload);
        // total_len = 2 + 13 + 2 + 0 + 4 + 1024 = 1045 bytes > 1024 max

        let mut reader = &message[..];
        let result = decoder.read_message(&mut reader).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("message too large"));
        assert!(err.to_string().contains("1045")); // total_len
        assert!(err.to_string().contains("1024")); // max_message_size
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_with_headers() {
        let endpoint = "api/v1/inference";
        let payload = b"Request payload data";

        // Mock headers encoded as JSON
        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("traceparent".to_string(), "00-abc123-def456-01".to_string());
        headers_map.insert("user-agent".to_string(), "test-client/1.0".to_string());
        headers_map.insert("request-id".to_string(), "req-12345".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();

        let message = encode_message(endpoint, &headers_json, payload);

        let mut reader = &message[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();

        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        // Total size includes all components
        assert_eq!(msg.total_size(), message.len());

        // Headers are correctly parsed
        let decoded_headers = msg.headers();
        assert_eq!(decoded_headers.len(), 3);
        assert_eq!(
            decoded_headers.get("traceparent").unwrap(),
            "00-abc123-def456-01"
        );
        assert_eq!(
            decoded_headers.get("user-agent").unwrap(),
            "test-client/1.0"
        );
        assert_eq!(decoded_headers.get("request-id").unwrap(), "req-12345");

        // headers_bytes returns the raw JSON
        let headers_bytes = msg.headers_bytes();
        assert_eq!(headers_bytes, &headers_json[..]);
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_empty_vs_populated_headers() {
        // Decode an empty-headers message and a populated-headers message
        // with the same decoder instance.
        let endpoint = "test/endpoint";
        let payload = b"test data";

        // Test 1: Empty headers
        let message_empty = encode_message(endpoint, &[], payload);
        let mut reader = &message_empty[..];
        let mut decoder = ZeroCopyTcpDecoder::new();
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 0);
        assert_eq!(msg.headers_bytes().len(), 0);

        // Test 2: Populated headers with same decoder
        let mut headers_map = std::collections::HashMap::new();
        headers_map.insert("x-test-header".to_string(), "test-value".to_string());
        let headers_json = serde_json::to_vec(&headers_map).unwrap();
        let message_with_headers = encode_message(endpoint, &headers_json, payload);
        let mut reader = &message_with_headers[..];
        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), endpoint);
        assert_eq!(msg.payload().as_ref(), payload);
        assert_eq!(msg.headers().len(), 1);
        assert_eq!(msg.headers().get("x-test-header").unwrap(), "test-value");
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_pipelined_messages() {
        // Two messages arriving back-to-back on one stream must both be
        // decoded: bytes read past the first message belong to the second.
        let first = encode_message("svc/first", &[], b"one");
        let second = encode_message("svc/second", &[], b"two-two");
        let stream = [first.as_slice(), second.as_slice()].concat();

        let mut reader = &stream[..];
        let mut decoder = ZeroCopyTcpDecoder::new();

        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), "svc/first");
        assert_eq!(msg.payload().as_ref(), b"one");
        assert_eq!(msg.total_size(), first.len());

        let msg = decoder.read_message(&mut reader).await.unwrap();
        assert_eq!(msg.endpoint_path().unwrap(), "svc/second");
        assert_eq!(msg.payload().as_ref(), b"two-two");
        assert_eq!(msg.total_size(), second.len());

        // Stream exhausted exactly on a message boundary: clean close.
        let err = decoder.read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(err.to_string().contains("connection closed"));
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_truncated_message() {
        // A stream that ends one byte short of the declared payload.
        let message = encode_message("test/endpoint", &[], b"payload");
        let mut reader = &message[..message.len() - 1];

        let mut decoder = ZeroCopyTcpDecoder::new();
        let err = decoder.read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert!(err.to_string().contains("incomplete message"));
    }

    #[tokio::test]
    async fn test_zero_copy_decoder_rejects_zero_path_len() {
        // path_len = 0 is invalid and must be rejected.
        let bytes = [0u8, 0u8];
        let mut reader = &bytes[..];

        let mut decoder = ZeroCopyTcpDecoder::new();
        let err = decoder.read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("invalid endpoint path length"));
    }
}
...@@ -25,6 +25,13 @@ use tracing::Instrument; ...@@ -25,6 +25,13 @@ use tracing::Instrument;
/// Default maximum message size for TCP server (32 MB) /// Default maximum message size for TCP server (32 MB)
const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024; const DEFAULT_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;
/// Default number of TCP requests the worker pool handles concurrently.
const DEFAULT_WORKER_POOL_SIZE: usize = 1500;

/// Default capacity of the work queue feeding the worker pool.
/// Kept at 4X the worker pool size to absorb burst traffic.
const DEFAULT_WORK_QUEUE_SIZE: usize = 4 * DEFAULT_WORKER_POOL_SIZE;
/// Get maximum message size from environment or use default /// Get maximum message size from environment or use default
fn get_max_message_size() -> usize { fn get_max_message_size() -> usize {
std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") std::env::var("DYN_TCP_MAX_MESSAGE_SIZE")
...@@ -33,6 +40,35 @@ fn get_max_message_size() -> usize { ...@@ -33,6 +40,35 @@ fn get_max_message_size() -> usize {
.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE) .unwrap_or(DEFAULT_MAX_MESSAGE_SIZE)
} }
/// Get worker pool size from environment or use default
///
/// Reads `DYN_TCP_WORKER_POOL_SIZE`. The default is used when the variable is
/// unset, does not parse as a `usize`, or is `0`. Zero is rejected because the
/// value becomes the dispatcher's semaphore permit count, and with zero
/// permits no queued request could ever start.
fn get_worker_pool_size() -> usize {
    std::env::var("DYN_TCP_WORKER_POOL_SIZE")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .filter(|&size| size > 0)
        .unwrap_or(DEFAULT_WORKER_POOL_SIZE)
}
/// Get work queue size from environment or use default
///
/// Reads `DYN_TCP_WORK_QUEUE_SIZE`. The default is used when the variable is
/// unset, does not parse as a `usize`, or is `0`. Zero is rejected because the
/// value is passed to `tokio::sync::mpsc::channel`, which panics when asked
/// for a zero-capacity buffer.
fn get_work_queue_size() -> usize {
    std::env::var("DYN_TCP_WORK_QUEUE_SIZE")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .filter(|&size| size > 0)
        .unwrap_or(DEFAULT_WORK_QUEUE_SIZE)
}
/// Work item for the worker pool
///
/// Carries everything needed to process one queued TCP request: the handler
/// to invoke, the request data, the per-endpoint in-flight accounting that
/// must be released afterwards, and the identifiers used for tracing.
struct WorkItem {
/// Handler whose `handle_payload` is called with `payload`.
service_handler: Arc<dyn PushWorkHandler>,
/// Request payload passed to the handler.
payload: Bytes,
/// Request headers, used to build the tracing span for the handler call.
headers: std::collections::HashMap<String, String>,
/// In-flight request counter; decremented once the item has been handled.
inflight: Arc<AtomicU64>,
/// Notified (one waiter) after `inflight` has been decremented.
notify: Arc<Notify>,
/// Instance id recorded on the tracing span and in log events.
instance_id: u64,
/// Namespace recorded on the tracing span.
namespace: String,
/// Component name recorded on the tracing span.
component_name: String,
/// Endpoint name recorded on the tracing span.
endpoint_name: String,
}
/// Shared TCP server that handles multiple endpoints on a single port /// Shared TCP server that handles multiple endpoints on a single port
pub struct SharedTcpServer { pub struct SharedTcpServer {
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>, handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
...@@ -41,6 +77,8 @@ pub struct SharedTcpServer { ...@@ -41,6 +77,8 @@ pub struct SharedTcpServer {
/// The actual bound address (populated after bind_and_start, contains actual port) /// The actual bound address (populated after bind_and_start, contains actual port)
actual_addr: RwLock<Option<SocketAddr>>, actual_addr: RwLock<Option<SocketAddr>>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
/// Channel for sending work to the worker pool
work_tx: tokio::sync::mpsc::Sender<WorkItem>,
} }
struct EndpointHandler { struct EndpointHandler {
...@@ -56,6 +94,21 @@ struct EndpointHandler { ...@@ -56,6 +94,21 @@ struct EndpointHandler {
impl SharedTcpServer { impl SharedTcpServer {
pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> { pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
let worker_pool_size = get_worker_pool_size();
let work_queue_size = get_work_queue_size();
tracing::info!(
"Initializing TCP server with dispatcher (concurrency={}, queue={})",
worker_pool_size,
work_queue_size
);
// Create bounded channel for work items
let (work_tx, work_rx) = tokio::sync::mpsc::channel(work_queue_size);
// Start worker pool
Self::start_worker_pool(worker_pool_size, work_rx, cancellation_token.clone());
Arc::new(Self { Arc::new(Self {
handlers: Arc::new(DashMap::new()), handlers: Arc::new(DashMap::new()),
// address we requested to bind to. // address we requested to bind to.
...@@ -63,9 +116,103 @@ impl SharedTcpServer { ...@@ -63,9 +116,103 @@ impl SharedTcpServer {
// actual address after free port assignment (if DYN_TCP_RPC_PORT is not specified) // actual address after free port assignment (if DYN_TCP_RPC_PORT is not specified)
actual_addr: RwLock::new(None), actual_addr: RwLock::new(None),
cancellation_token, cancellation_token,
work_tx,
}) })
} }
/// Start the worker pool dispatcher that processes requests with bounded concurrency
///
/// Uses a single receiver with a semaphore to bound concurrent execution,
/// avoiding mutex contention that would serialize all workers.
fn start_worker_pool(
pool_size: usize,
mut work_rx: tokio::sync::mpsc::Receiver<WorkItem>,
cancellation_token: CancellationToken,
) {
let semaphore = Arc::new(tokio::sync::Semaphore::new(pool_size));
tokio::spawn(async move {
tracing::trace!(
"TCP worker dispatcher started with concurrency limit {}",
pool_size
);
loop {
tokio::select! {
biased;
_ = cancellation_token.cancelled() => {
tracing::trace!("TCP worker dispatcher shutting down: cancellation requested");
break;
}
msg = work_rx.recv() => {
let Some(work_item) = msg else {
tracing::trace!("TCP worker dispatcher shutting down: channel closed");
break;
};
// Acquire permit before spawning (bounds concurrency)
let permit = match semaphore.clone().acquire_owned().await {
Ok(p) => p,
Err(_) => {
tracing::trace!("TCP worker dispatcher: semaphore closed");
break;
}
};
// Spawn task with owned permit (dropped when task completes)
tokio::spawn(async move {
Self::handle_work_item(work_item).await;
drop(permit);
});
}
}
}
tracing::trace!("TCP worker dispatcher exited");
});
tracing::info!(
"Started TCP worker dispatcher with concurrency limit {}",
pool_size
);
}
/// Handle a single work item
async fn handle_work_item(work_item: WorkItem) {
tracing::trace!(
instance_id = work_item.instance_id,
"TCP worker processing request"
);
// Create span with trace context from headers
let span = crate::logging::make_handle_payload_span_from_tcp_headers(
&work_item.headers,
&work_item.component_name,
&work_item.endpoint_name,
&work_item.namespace,
work_item.instance_id,
);
let result = work_item
.service_handler
.handle_payload(work_item.payload)
.instrument(span)
.await;
if let Err(e) = result {
tracing::warn!(
instance_id = work_item.instance_id,
error = %e,
"TCP worker failed to handle request"
);
}
work_item.inflight.fetch_sub(1, Ordering::SeqCst);
work_item.notify.notify_one();
}
/// Bind the server and start accepting connections. /// Bind the server and start accepting connections.
/// ///
/// This method binds to the configured address first, then starts the accept loop. /// This method binds to the configured address first, then starts the accept loop.
...@@ -116,8 +263,9 @@ impl SharedTcpServer { ...@@ -116,8 +263,9 @@ impl SharedTcpServer {
tracing::trace!("Accepted TCP connection from {}", peer_addr); tracing::trace!("Accepted TCP connection from {}", peer_addr);
let handlers = self.handlers.clone(); let handlers = self.handlers.clone();
let work_tx = self.work_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, handlers).await { if let Err(e) = Self::handle_connection(stream, handlers, work_tx).await {
tracing::error!("TCP connection error: {}", e); tracing::error!("TCP connection error: {}", e);
} }
}); });
...@@ -219,6 +367,7 @@ impl SharedTcpServer { ...@@ -219,6 +367,7 @@ impl SharedTcpServer {
async fn handle_connection( async fn handle_connection(
stream: TcpStream, stream: TcpStream,
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>, handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
work_tx: tokio::sync::mpsc::Sender<WorkItem>,
) -> Result<()> { ) -> Result<()> {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage}; use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage};
...@@ -232,7 +381,7 @@ impl SharedTcpServer { ...@@ -232,7 +381,7 @@ impl SharedTcpServer {
let write_task = tokio::spawn(Self::write_loop(write_half, response_rx)); let write_task = tokio::spawn(Self::write_loop(write_half, response_rx));
// Run read task in current context // Run read task in current context
let read_result = Self::read_loop(read_half, handlers, response_tx).await; let read_result = Self::read_loop(read_half, handlers, response_tx, work_tx).await;
// Write task will end when response_tx is dropped // Write task will end when response_tx is dropped
write_task.await??; write_task.await??;
...@@ -244,82 +393,40 @@ impl SharedTcpServer { ...@@ -244,82 +393,40 @@ impl SharedTcpServer {
mut read_half: tokio::io::ReadHalf<TcpStream>, mut read_half: tokio::io::ReadHalf<TcpStream>,
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>, handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
response_tx: tokio::sync::mpsc::UnboundedSender<Bytes>, response_tx: tokio::sync::mpsc::UnboundedSender<Bytes>,
work_tx: tokio::sync::mpsc::Sender<WorkItem>,
) -> Result<()> { ) -> Result<()> {
use crate::pipeline::network::codec::{TcpRequestMessage, TcpResponseMessage}; use crate::pipeline::network::codec::{TcpResponseMessage, ZeroCopyTcpDecoder};
// Create zero-copy decoder with optimized buffer size
let mut decoder = ZeroCopyTcpDecoder::new();
loop { loop {
// Read endpoint path length (2 bytes) // Read one complete message with ZERO copies!
let mut path_len_buf = [0u8; 2]; let request_msg = match decoder.read_message(&mut read_half).await {
match read_half.read_exact(&mut path_len_buf).await { Ok(msg) => msg,
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
tracing::trace!("Connection closed by peer");
break; break;
} }
Err(e) => { Err(e) => {
return Err(e.into()); tracing::warn!("Failed to read TCP request: {}", e);
}
}
let path_len = u16::from_be_bytes(path_len_buf) as usize;
// Read endpoint path
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await?;
// Read headers length (2 bytes)
let mut headers_len_buf = [0u8; 2];
read_half.read_exact(&mut headers_len_buf).await?;
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
// Read headers
let mut headers_buf = vec![0u8; headers_len];
read_half.read_exact(&mut headers_buf).await?;
// Read payload length (4 bytes)
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await?;
let payload_len = u32::from_be_bytes(len_buf) as usize;
// Sanity check - enforce maximum message size
let max_message_size = get_max_message_size();
if payload_len > max_message_size {
tracing::warn!(
"Request too large: {} bytes (max: {} bytes), closing connection",
payload_len,
max_message_size
);
// Send error response // Send error response
let error_response = let error_response =
TcpResponseMessage::new(Bytes::from_static(b"Request too large")); TcpResponseMessage::new(Bytes::from(format!("Read error: {}", e)));
if let Ok(encoded) = error_response.encode() { if let Ok(encoded) = error_response.encode() {
let _ = response_tx.send(encoded); let _ = response_tx.send(encoded);
} }
break; return Err(e.into());
} }
};
// Read request payload // Get endpoint path (zero-copy string slice)
let mut payload_buf = vec![0u8; payload_len]; let endpoint_path = match request_msg.endpoint_path() {
read_half.read_exact(&mut payload_buf).await?; Ok(path) => path,
// Reconstruct the full message buffer for decoding using BytesMut
let mut full_msg =
BytesMut::with_capacity(2 + path_len + 2 + headers_len + 4 + payload_len);
full_msg.extend_from_slice(&path_len_buf);
full_msg.extend_from_slice(&path_buf);
full_msg.extend_from_slice(&headers_len_buf);
full_msg.extend_from_slice(&headers_buf);
full_msg.extend_from_slice(&len_buf);
full_msg.extend_from_slice(&payload_buf);
// Decode using codec (zero-copy conversion)
let full_msg_bytes = full_msg.freeze();
let request_msg = match TcpRequestMessage::decode(&full_msg_bytes) {
Ok(msg) => msg,
Err(e) => { Err(e) => {
tracing::warn!("Failed to decode TCP request: {}", e); tracing::warn!("Invalid UTF-8 in endpoint path: {}", e);
// Send error response
let error_response = let error_response =
TcpResponseMessage::new(Bytes::from(format!("Decode error: {}", e))); TcpResponseMessage::new(Bytes::from_static(b"Invalid endpoint path"));
if let Ok(encoded) = error_response.encode() { if let Ok(encoded) = error_response.encode() {
let _ = response_tx.send(encoded); let _ = response_tx.send(encoded);
} }
...@@ -327,18 +434,27 @@ impl SharedTcpServer { ...@@ -327,18 +434,27 @@ impl SharedTcpServer {
} }
}; };
let endpoint_path = request_msg.endpoint_path; // Get headers (parsed from message)
let headers = request_msg.headers; let headers = request_msg.headers();
let payload = request_msg.payload;
// Get payload (zero-copy Bytes - just Arc clone!)
let payload = request_msg.payload();
tracing::trace!(
endpoint = endpoint_path,
payload_len = payload.len(),
total_size = request_msg.total_size(),
"Received TCP request"
);
// Look up handler (lock-free read with DashMap) // Look up handler (lock-free read with DashMap)
let handler = handlers.get(&endpoint_path).map(|h| h.clone()); let handler = handlers.get(endpoint_path).map(|h| h.clone());
let handler = match handler { let handler = match handler {
Some(h) => h, Some(h) => h,
None => { None => {
tracing::warn!("No handler found for endpoint: {}", endpoint_path); tracing::warn!("No handler found for endpoint: {}", endpoint_path);
// Send error response using codec // Send error response
let error_response = TcpResponseMessage::new(Bytes::from(format!( let error_response = TcpResponseMessage::new(Bytes::from(format!(
"Unknown endpoint: {}", "Unknown endpoint: {}",
endpoint_path endpoint_path
...@@ -352,54 +468,67 @@ impl SharedTcpServer { ...@@ -352,54 +468,67 @@ impl SharedTcpServer {
handler.inflight.fetch_add(1, Ordering::SeqCst); handler.inflight.fetch_add(1, Ordering::SeqCst);
// Send acknowledgment immediately using codec (non-blocking, zero-copy) // Build work item
// NOTE: payload is Bytes (Arc-counted), so cloning is extremely cheap
let work_item = WorkItem {
service_handler: handler.service_handler.clone(),
payload,
headers,
inflight: handler.inflight.clone(),
notify: handler.notify.clone(),
instance_id: handler.instance_id,
namespace: handler.namespace.clone(),
component_name: handler.component_name.clone(),
endpoint_name: handler.endpoint_name.clone(),
};
// Send to worker pool with backpressure - BEFORE sending ACK
match work_tx.send(work_item).await {
Ok(_) => {
// Send acknowledgment ONLY after successful queuing
let ack_response = TcpResponseMessage::empty(); let ack_response = TcpResponseMessage::empty();
if let Ok(encoded_ack) = ack_response.encode() { if let Ok(encoded_ack) = ack_response.encode()
// Send to write task without blocking reads && response_tx.send(encoded_ack).is_err()
if response_tx.send(encoded_ack).is_err() { {
tracing::debug!("Write task closed, ending read loop"); tracing::debug!("Write task closed, ending read loop");
// Clean up inflight counter since work was queued but ACK failed
handler.inflight.fetch_sub(1, Ordering::SeqCst);
handler.notify.notify_one();
break; break;
} }
}
// Process request asynchronously
let service_handler = handler.service_handler.clone();
let inflight = handler.inflight.clone();
let notify = handler.notify.clone();
let instance_id = handler.instance_id;
let namespace = handler.namespace.clone();
let component_name = handler.component_name.clone();
let endpoint_name = handler.endpoint_name.clone();
tokio::spawn(async move {
tracing::trace!(instance_id, "handling TCP request");
// Create span with trace context from headers tracing::trace!(
let span = crate::logging::make_handle_payload_span_from_tcp_headers( endpoint = handler.endpoint_name.as_str(),
&headers, instance_id = handler.instance_id,
&component_name, "Request queued and acknowledged"
&endpoint_name, );
&namespace, }
instance_id, Err(e) => {
tracing::warn!(
endpoint = handler.endpoint_name.as_str(),
instance_id = handler.instance_id,
error = %e,
"Failed to queue work to worker pool, sending error response"
); );
let result = service_handler // Send error response to client instead of ACK
.handle_payload(payload) let error_response =
.instrument(span) TcpResponseMessage::new(Bytes::from(format!("Server overloaded: {}", e)));
.await; if let Ok(encoded) = error_response.encode() {
let _ = response_tx.send(encoded);
}
match result { // Clean up inflight counter
Ok(_) => { handler.inflight.fetch_sub(1, Ordering::SeqCst);
tracing::trace!(instance_id, "TCP request handled successfully"); handler.notify.notify_one();
// If channel is closed, break the loop
if matches!(e, tokio::sync::mpsc::error::SendError(_)) {
tracing::error!("Worker pool channel closed, shutting down read loop");
break;
} }
Err(e) => {
tracing::warn!("Failed to handle TCP request: {}", e);
} }
} }
inflight.fetch_sub(1, Ordering::SeqCst);
notify.notify_one();
});
} }
Ok(()) Ok(())
...@@ -692,4 +821,138 @@ mod tests { ...@@ -692,4 +821,138 @@ mod tests {
tracing::info!("Test passed: unregister_endpoint properly waited for inflight TCP request"); tracing::info!("Test passed: unregister_endpoint properly waited for inflight TCP request");
} }
///////////////////// TESTS FOR CONCURRENCY BOUNDING /////////////////////
/// Mock request handler used by the concurrency-bounding tests.
///
/// Every `handle_payload` call bumps `concurrent_count` on entry and drops it
/// again on exit, recording the peak in `max_concurrent`, so a test can check
/// how many calls were ever running at the same time.
struct ConcurrencyTrackingHandler {
/// Number of `handle_payload` calls currently in progress
concurrent_count: Arc<AtomicU64>,
/// Highest value `concurrent_count` has reached so far
max_concurrent: Arc<AtomicU64>,
/// How long each call sleeps to simulate request processing
processing_duration: Duration,
/// Signalled (via `notify_one`) each time a call finishes
completed: Arc<Notify>,
}
impl ConcurrencyTrackingHandler {
    /// Builds a handler whose counters start at zero and whose simulated
    /// work takes `processing_duration` per request.
    fn new(processing_duration: Duration) -> Self {
        // Both gauges start at zero: nothing is running and no peak has been seen.
        let concurrent_count = Arc::new(AtomicU64::new(0));
        let max_concurrent = Arc::new(AtomicU64::new(0));
        let completed = Arc::new(Notify::new());

        Self {
            concurrent_count,
            max_concurrent,
            processing_duration,
            completed,
        }
    }
}
#[async_trait]
impl PushWorkHandler for ConcurrencyTrackingHandler {
    /// Simulates handling one request while tracking how many calls overlap.
    ///
    /// The payload is ignored; the call always succeeds after sleeping for
    /// the configured processing duration.
    async fn handle_payload(&self, _payload: Bytes) -> Result<(), PipelineError> {
        // `fetch_add` returns the previous value, so add one to include this call.
        let running_now = self.concurrent_count.fetch_add(1, Ordering::SeqCst) + 1;

        // Record a new high-water mark if this call pushed the count past it.
        self.max_concurrent.fetch_max(running_now, Ordering::SeqCst);

        // Stand-in for real request processing.
        tokio::time::sleep(self.processing_duration).await;

        // Leave the in-progress set, then signal that a request finished.
        self.concurrent_count.fetch_sub(1, Ordering::SeqCst);
        self.completed.notify_one();

        Ok(())
    }

    /// No metrics are registered for this mock; always returns `Ok(())`.
    fn add_metrics(
        &self,
        _endpoint: &crate::component::Endpoint,
        _metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        Ok(())
    }
}
/// Queues more work items than the worker pool has slots and checks that the
/// number of handlers running at once never exceeds the pool size.
#[tokio::test]
async fn test_worker_pool_bounds_concurrency() {
    let _ = tracing_subscriber::fmt()
        .with_test_writer()
        .with_max_level(tracing::Level::DEBUG)
        .try_init();

    // Deliberately tiny pool, with several times as many requests as slots.
    let pool_size = 3;
    let total_requests = 10;

    // The channel is sized to hold every request, so the sends below never block.
    let (work_tx, work_rx) = tokio::sync::mpsc::channel::<WorkItem>(total_requests);
    let cancellation_token = CancellationToken::new();

    // Spin up the pool under test with the small concurrency limit.
    SharedTcpServer::start_worker_pool(pool_size, work_rx, cancellation_token.clone());

    // Handler that records the peak number of overlapping calls.
    let handler = Arc::new(ConcurrencyTrackingHandler::new(Duration::from_millis(50)));

    // Stand-ins for the per-endpoint inflight counter and its notifier.
    let inflight = Arc::new(AtomicU64::new(0));
    let notify = Arc::new(Notify::new());

    // Enqueue everything up front; each item is counted as inflight before it is sent.
    for request_idx in 0..total_requests {
        inflight.fetch_add(1, Ordering::SeqCst);

        let service_handler = handler.clone() as Arc<dyn PushWorkHandler>;
        let payload = Bytes::from(format!("request {}", request_idx));

        let queued = work_tx
            .send(WorkItem {
                service_handler,
                payload,
                headers: std::collections::HashMap::new(),
                inflight: inflight.clone(),
                notify: notify.clone(),
                instance_id: 1,
                namespace: "test".to_string(),
                component_name: "test".to_string(),
                endpoint_name: "test".to_string(),
            })
            .await;
        queued.expect("send should succeed");
    }

    // Block until the inflight counter drains to zero, giving up after five seconds.
    let drained = tokio::time::timeout(Duration::from_secs(5), async {
        loop {
            if inflight.load(Ordering::SeqCst) == 0 {
                break;
            }
            notify.notified().await;
        }
    })
    .await;
    assert!(
        drained.is_ok(),
        "All requests should complete within timeout"
    );

    // The peak overlap seen by the handler must fit within the pool.
    let max_observed = handler.max_concurrent.load(Ordering::SeqCst);
    assert!(
        max_observed <= pool_size as u64,
        "Max concurrent ({}) should not exceed pool size ({})",
        max_observed,
        pool_size
    );

    // Nothing may still be counted as inflight.
    let still_inflight = inflight.load(Ordering::SeqCst);
    assert_eq!(still_inflight, 0, "All requests should have completed");

    tracing::info!(
        "Test passed: max concurrent {} <= pool size {}",
        max_observed,
        pool_size
    );

    // Stop the worker pool.
    cancellation_token.cancel();
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment