Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
bbb79afb
Unverified
Commit
bbb79afb
authored
Jan 13, 2026
by
Biswa Panda
Committed by
GitHub
Jan 13, 2026
Browse files
fix: use zero copy decoder for handling high concurrency / request bursts for tcp ingress (#5376)
parent
1da603a4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
882 additions
and
114 deletions
+882
-114
lib/runtime/src/pipeline/network/codec.rs
lib/runtime/src/pipeline/network/codec.rs
+2
-0
lib/runtime/src/pipeline/network/codec/zero_copy_decoder.rs
lib/runtime/src/pipeline/network/codec/zero_copy_decoder.rs
+503
-0
lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
...ntime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
+377
-114
No files found.
lib/runtime/src/pipeline/network/codec.rs
View file @
bbb79afb
...
...
@@ -15,8 +15,10 @@ use tokio_util::{
};
mod
two_part
;
pub
mod
zero_copy_decoder
;
pub
use
two_part
::{
TwoPartCodec
,
TwoPartMessage
,
TwoPartMessageType
};
pub
use
zero_copy_decoder
::{
TcpRequestMessageZeroCopy
,
ZeroCopyTcpDecoder
};
/// TCP request plane protocol message with endpoint routing and trace headers
///
...
...
lib/runtime/src/pipeline/network/codec/zero_copy_decoder.rs
0 → 100644
View file @
bbb79afb
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Zero-copy TCP message decoder for high-concurrency scenarios
//!
//! This decoder eliminates message reconstruction copies by:
//! 1. Reading into a reusable buffer
//! 2. Parsing headers in-place
//! 3. Splitting off exact message sizes (zero-copy via Bytes::split_to)
//! 4. Returning Arc-counted Bytes that can be cloned cheaply
use
bytes
::{
Buf
,
Bytes
,
BytesMut
};
use
std
::
io
;
use
tokio
::
io
::{
AsyncRead
,
AsyncReadExt
};
/// Maximum message size (32MB default, configurable via env)
const
MAX_MESSAGE_SIZE
:
usize
=
32
*
1024
*
1024
;
// 32MB
const
INITIAL_BUFFER_SIZE
:
usize
=
262144
;
// 256KB
fn
get_max_message_size
()
->
usize
{
std
::
env
::
var
(
"DYN_TCP_MAX_MESSAGE_SIZE"
)
.ok
()
.and_then
(|
s
|
s
.parse
::
<
usize
>
()
.ok
())
.unwrap_or
(
MAX_MESSAGE_SIZE
)
}
/// Zero-copy streaming decoder that reuses buffers
///
/// This decoder maintains an internal buffer and only allocates when necessary.
/// Messages are returned as Arc-counted Bytes slices, making cloning extremely cheap.
pub
struct
ZeroCopyTcpDecoder
{
/// Reusable read buffer - grows as needed but never shrinks
read_buffer
:
BytesMut
,
/// Maximum allowed message size
max_message_size
:
usize
,
}
impl
ZeroCopyTcpDecoder
{
/// Create a new decoder with default buffer size
pub
fn
new
()
->
Self
{
Self
::
with_capacity
(
INITIAL_BUFFER_SIZE
)
}
/// Create a new decoder with specific initial capacity
pub
fn
with_capacity
(
capacity
:
usize
)
->
Self
{
Self
{
read_buffer
:
BytesMut
::
with_capacity
(
capacity
),
max_message_size
:
get_max_message_size
(),
}
}
/// Read one complete message with ZERO copies
///
/// This method:
/// 1. Ensures headers are buffered
/// 2. Parses headers in-place (no allocation)
/// 3. Ensures entire message is buffered
/// 4. Splits off exact message size (zero-copy pointer arithmetic)
/// 5. Returns Arc-counted Bytes (cheap to clone)
pub
async
fn
read_message
<
R
:
AsyncRead
+
Unpin
>
(
&
mut
self
,
reader
:
&
mut
R
,
)
->
io
::
Result
<
TcpRequestMessageZeroCopy
>
{
// Ensure we have at least enough bytes to start parsing
// Wire format: [path_len(2)][path][headers_len(2)][headers][payload_len(4)][payload]
const
MIN_HEADER_SIZE
:
usize
=
2
;
// Fill buffer if needed
while
self
.read_buffer
.len
()
<
MIN_HEADER_SIZE
{
let
n
=
reader
.read_buf
(
&
mut
self
.read_buffer
)
.await
?
;
if
n
==
0
{
if
self
.read_buffer
.is_empty
()
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
UnexpectedEof
,
"connection closed"
,
));
}
else
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
UnexpectedEof
,
"incomplete message header"
,
));
}
}
}
// Parse endpoint path length (first 2 bytes) - NO COPY
let
path_len
=
u16
::
from_be_bytes
([
self
.read_buffer
[
0
],
self
.read_buffer
[
1
]])
as
usize
;
// Sanity check path length
if
path_len
==
0
||
path_len
>
1024
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
InvalidData
,
format!
(
"invalid endpoint path length: {}"
,
path_len
),
));
}
// Ensure we have path + headers_len
let
initial_header_size
=
2
+
path_len
+
2
;
// path_len(2) + path + headers_len(2)
while
self
.read_buffer
.len
()
<
initial_header_size
{
let
n
=
reader
.read_buf
(
&
mut
self
.read_buffer
)
.await
?
;
if
n
==
0
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
UnexpectedEof
,
"incomplete message header"
,
));
}
}
// Parse headers length (2 bytes after path) - NO COPY
let
headers_len_offset
=
2
+
path_len
;
let
headers_len
=
u16
::
from_be_bytes
([
self
.read_buffer
[
headers_len_offset
],
self
.read_buffer
[
headers_len_offset
+
1
],
])
as
usize
;
// Ensure we have headers + payload length
let
full_header_size
=
2
+
path_len
+
2
+
headers_len
+
4
;
// path_len(2) + path + headers_len(2) + headers + payload_len(4)
while
self
.read_buffer
.len
()
<
full_header_size
{
let
n
=
reader
.read_buf
(
&
mut
self
.read_buffer
)
.await
?
;
if
n
==
0
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
UnexpectedEof
,
"incomplete message header"
,
));
}
}
// Parse payload length (4 bytes after headers) - NO COPY
let
payload_len_offset
=
2
+
path_len
+
2
+
headers_len
;
let
payload_len
=
u32
::
from_be_bytes
([
self
.read_buffer
[
payload_len_offset
],
self
.read_buffer
[
payload_len_offset
+
1
],
self
.read_buffer
[
payload_len_offset
+
2
],
self
.read_buffer
[
payload_len_offset
+
3
],
])
as
usize
;
// Calculate total message size
let
total_len
=
2
+
path_len
+
2
+
headers_len
+
4
+
payload_len
;
// Sanity check total message length (including all overhead)
if
total_len
>
self
.max_message_size
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
InvalidData
,
format!
(
"message too large: {} bytes (max: {} bytes)"
,
total_len
,
self
.max_message_size
),
));
}
// Ensure entire message is buffered
while
self
.read_buffer
.len
()
<
total_len
{
let
n
=
reader
.read_buf
(
&
mut
self
.read_buffer
)
.await
?
;
if
n
==
0
{
return
Err
(
io
::
Error
::
new
(
io
::
ErrorKind
::
UnexpectedEof
,
format!
(
"incomplete message: expected {} bytes, got {}"
,
total_len
,
self
.read_buffer
.len
()
),
));
}
}
// Split off exactly what we need - ZERO COPY!
// split_to() just advances the internal pointer, doesn't allocate or copy
let
message_bytes
=
self
.read_buffer
.split_to
(
total_len
)
.freeze
();
// Return zero-copy message wrapper
Ok
(
TcpRequestMessageZeroCopy
::
new
(
message_bytes
))
}
/// Get the current buffer capacity
pub
fn
buffer_capacity
(
&
self
)
->
usize
{
self
.read_buffer
.capacity
()
}
/// Get the current buffered data size
pub
fn
buffered_len
(
&
self
)
->
usize
{
self
.read_buffer
.len
()
}
}
impl
Default
for
ZeroCopyTcpDecoder
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
/// Zero-copy message representation
///
/// This struct holds an Arc-counted Bytes buffer containing the entire message.
/// All accessors return zero-copy slices or references into this buffer.
#[derive(Clone)]
pub
struct
TcpRequestMessageZeroCopy
{
/// Entire message as Arc-counted buffer
/// Format: [path_len(2)][path(var)][headers_len(2)][headers(var)][payload_len(4)][payload(var)]
raw
:
Bytes
,
}
impl
TcpRequestMessageZeroCopy
{
/// Create a new zero-copy message from raw bytes
fn
new
(
raw
:
Bytes
)
->
Self
{
Self
{
raw
}
}
/// Get the endpoint path length
#[inline]
fn
path_len
(
&
self
)
->
usize
{
u16
::
from_be_bytes
([
self
.raw
[
0
],
self
.raw
[
1
]])
as
usize
}
/// Get endpoint path as a string slice (zero-copy)
///
/// This returns a reference into the message buffer, no allocation.
pub
fn
endpoint_path
(
&
self
)
->
Result
<&
str
,
std
::
str
::
Utf8Error
>
{
let
path_len
=
self
.path_len
();
std
::
str
::
from_utf8
(
&
self
.raw
[
2
..
2
+
path_len
])
}
/// Get endpoint path as bytes (zero-copy)
pub
fn
endpoint_path_bytes
(
&
self
)
->
&
[
u8
]
{
let
path_len
=
self
.path_len
();
&
self
.raw
[
2
..
2
+
path_len
]
}
/// Get the headers length
#[inline]
fn
headers_len
(
&
self
)
->
usize
{
let
path_len
=
self
.path_len
();
let
offset
=
2
+
path_len
;
u16
::
from_be_bytes
([
self
.raw
[
offset
],
self
.raw
[
offset
+
1
]])
as
usize
}
/// Get headers as bytes (zero-copy)
pub
fn
headers_bytes
(
&
self
)
->
&
[
u8
]
{
let
path_len
=
self
.path_len
();
let
headers_len
=
self
.headers_len
();
let
headers_start
=
2
+
path_len
+
2
;
&
self
.raw
[
headers_start
..
headers_start
+
headers_len
]
}
/// Get headers as a HashMap (requires parsing)
pub
fn
headers
(
&
self
)
->
std
::
collections
::
HashMap
<
String
,
String
>
{
let
headers_bytes
=
self
.headers_bytes
();
if
headers_bytes
.is_empty
()
{
return
std
::
collections
::
HashMap
::
new
();
}
// Parse headers from JSON format
serde_json
::
from_slice
(
headers_bytes
)
.unwrap_or_default
()
}
/// Get the payload length
#[inline]
fn
payload_len
(
&
self
)
->
usize
{
let
path_len
=
self
.path_len
();
let
headers_len
=
self
.headers_len
();
let
offset
=
2
+
path_len
+
2
+
headers_len
;
u32
::
from_be_bytes
([
self
.raw
[
offset
],
self
.raw
[
offset
+
1
],
self
.raw
[
offset
+
2
],
self
.raw
[
offset
+
3
],
])
as
usize
}
/// Get payload as zero-copy Bytes
///
/// This returns an Arc-counted slice of the message buffer.
/// Cloning the returned Bytes is extremely cheap (just Arc clone).
pub
fn
payload
(
&
self
)
->
Bytes
{
let
path_len
=
self
.path_len
();
let
headers_len
=
self
.headers_len
();
let
payload_start
=
2
+
path_len
+
2
+
headers_len
+
4
;
self
.raw
.slice
(
payload_start
..
)
// ZERO COPY! Just Arc clone + offset
}
/// Get total message size in bytes
pub
fn
total_size
(
&
self
)
->
usize
{
self
.raw
.len
()
}
/// Get the raw message bytes (for debugging)
pub
fn
raw_bytes
(
&
self
)
->
&
Bytes
{
&
self
.raw
}
}
impl
std
::
fmt
::
Debug
for
TcpRequestMessageZeroCopy
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
f
.debug_struct
(
"TcpRequestMessageZeroCopy"
)
.field
(
"total_size"
,
&
self
.total_size
())
.field
(
"endpoint_path"
,
&
self
.endpoint_path
()
.ok
())
.field
(
"payload_len"
,
&
self
.payload_len
())
.finish
()
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
tokio
::
io
::
AsyncWriteExt
;
#[tokio::test]
async
fn
test_zero_copy_decoder_basic
()
{
// Create a test message with headers
let
endpoint
=
"test/endpoint"
;
let
payload
=
b
"Hello, World!"
;
let
headers
:
Vec
<
u8
>
=
vec!
[];
// Empty headers
let
mut
message
=
Vec
::
new
();
// path_len + path
message
.extend_from_slice
(
&
(
endpoint
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
endpoint
.as_bytes
());
// headers_len + headers
message
.extend_from_slice
(
&
(
headers
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
&
headers
);
// payload_len + payload
message
.extend_from_slice
(
&
(
payload
.len
()
as
u32
)
.to_be_bytes
());
message
.extend_from_slice
(
payload
);
// Create a mock reader
let
mut
reader
=
&
message
[
..
];
// Decode
let
mut
decoder
=
ZeroCopyTcpDecoder
::
new
();
let
msg
=
decoder
.read_message
(
&
mut
reader
)
.await
.unwrap
();
// Verify
assert_eq!
(
msg
.endpoint_path
()
.unwrap
(),
endpoint
);
assert_eq!
(
msg
.payload
()
.as_ref
(),
payload
);
assert_eq!
(
msg
.total_size
(),
message
.len
());
assert_eq!
(
msg
.headers
()
.len
(),
0
);
// Empty headers
}
#[tokio::test]
async
fn
test_zero_copy_decoder_large_payload
()
{
// Create a large payload (200KB)
let
endpoint
=
"large/endpoint"
;
let
payload
=
vec!
[
0x42u8
;
200
*
1024
];
let
headers
:
Vec
<
u8
>
=
vec!
[];
// Empty headers
let
mut
message
=
Vec
::
new
();
// path_len + path
message
.extend_from_slice
(
&
(
endpoint
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
endpoint
.as_bytes
());
// headers_len + headers
message
.extend_from_slice
(
&
(
headers
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
&
headers
);
// payload_len + payload
message
.extend_from_slice
(
&
(
payload
.len
()
as
u32
)
.to_be_bytes
());
message
.extend_from_slice
(
&
payload
);
let
mut
reader
=
&
message
[
..
];
let
mut
decoder
=
ZeroCopyTcpDecoder
::
new
();
let
msg
=
decoder
.read_message
(
&
mut
reader
)
.await
.unwrap
();
assert_eq!
(
msg
.endpoint_path
()
.unwrap
(),
endpoint
);
assert_eq!
(
msg
.payload
()
.len
(),
payload
.len
());
}
#[tokio::test]
async
fn
test_zero_copy_decoder_total_size_limit
()
{
// Test that the decoder validates total message size, not just payload size
// Create a message where total_len exceeds max but payload alone might not
let
max_size
=
1024
;
// 1KB limit
let
mut
decoder
=
ZeroCopyTcpDecoder
::
with_capacity
(
256
);
decoder
.max_message_size
=
max_size
;
// Create a message that exceeds the limit with overhead included
let
endpoint
=
"test/endpoint"
;
let
payload
=
vec!
[
0x42u8
;
max_size
];
// Payload equals max
let
headers
:
Vec
<
u8
>
=
vec!
[];
// Empty headers
let
mut
message
=
Vec
::
new
();
// path_len + path
message
.extend_from_slice
(
&
(
endpoint
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
endpoint
.as_bytes
());
// headers_len + headers
message
.extend_from_slice
(
&
(
headers
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
&
headers
);
// payload_len + payload
message
.extend_from_slice
(
&
(
payload
.len
()
as
u32
)
.to_be_bytes
());
message
.extend_from_slice
(
&
payload
);
// total_len = 2 + 13 + 2 + 0 + 4 + 1024 = 1045 bytes > 1024 max
let
mut
reader
=
&
message
[
..
];
let
result
=
decoder
.read_message
(
&
mut
reader
)
.await
;
// Should fail with InvalidData error
assert
!
(
result
.is_err
());
let
err
=
result
.unwrap_err
();
assert_eq!
(
err
.kind
(),
io
::
ErrorKind
::
InvalidData
);
assert
!
(
err
.to_string
()
.contains
(
"message too large"
));
assert
!
(
err
.to_string
()
.contains
(
"1045"
));
// total_len
assert
!
(
err
.to_string
()
.contains
(
"1024"
));
// max_message_size
}
#[tokio::test]
async
fn
test_zero_copy_decoder_with_headers
()
{
// Test header parsing with actual header data
let
endpoint
=
"api/v1/inference"
;
let
payload
=
b
"Request payload data"
;
// Create mock headers as JSON
let
mut
headers_map
=
std
::
collections
::
HashMap
::
new
();
headers_map
.insert
(
"traceparent"
.to_string
(),
"00-abc123-def456-01"
.to_string
());
headers_map
.insert
(
"user-agent"
.to_string
(),
"test-client/1.0"
.to_string
());
headers_map
.insert
(
"request-id"
.to_string
(),
"req-12345"
.to_string
());
let
headers_json
=
serde_json
::
to_vec
(
&
headers_map
)
.unwrap
();
let
mut
message
=
Vec
::
new
();
// path_len + path
message
.extend_from_slice
(
&
(
endpoint
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
endpoint
.as_bytes
());
// headers_len + headers (non-empty this time)
message
.extend_from_slice
(
&
(
headers_json
.len
()
as
u16
)
.to_be_bytes
());
message
.extend_from_slice
(
&
headers_json
);
// payload_len + payload
message
.extend_from_slice
(
&
(
payload
.len
()
as
u32
)
.to_be_bytes
());
message
.extend_from_slice
(
payload
);
// Decode the message
let
mut
reader
=
&
message
[
..
];
let
mut
decoder
=
ZeroCopyTcpDecoder
::
new
();
let
msg
=
decoder
.read_message
(
&
mut
reader
)
.await
.unwrap
();
// Verify endpoint
assert_eq!
(
msg
.endpoint_path
()
.unwrap
(),
endpoint
);
// Verify payload
assert_eq!
(
msg
.payload
()
.as_ref
(),
payload
);
// Verify total size includes all components
assert_eq!
(
msg
.total_size
(),
message
.len
());
// Verify headers are correctly parsed
let
decoded_headers
=
msg
.headers
();
assert_eq!
(
decoded_headers
.len
(),
3
);
assert_eq!
(
decoded_headers
.get
(
"traceparent"
)
.unwrap
(),
"00-abc123-def456-01"
);
assert_eq!
(
decoded_headers
.get
(
"user-agent"
)
.unwrap
(),
"test-client/1.0"
);
assert_eq!
(
decoded_headers
.get
(
"request-id"
)
.unwrap
(),
"req-12345"
);
// Verify headers_bytes returns the raw JSON
let
headers_bytes
=
msg
.headers_bytes
();
assert_eq!
(
headers_bytes
,
&
headers_json
[
..
]);
}
#[tokio::test]
async
fn
test_zero_copy_decoder_empty_vs_populated_headers
()
{
// Test both empty and populated headers in sequence to ensure proper parsing
let
endpoint
=
"test/endpoint"
;
let
payload
=
b
"test data"
;
// Test 1: Empty headers
let
mut
message_empty
=
Vec
::
new
();
message_empty
.extend_from_slice
(
&
(
endpoint
.len
()
as
u16
)
.to_be_bytes
());
message_empty
.extend_from_slice
(
endpoint
.as_bytes
());
message_empty
.extend_from_slice
(
&
(
0u16
)
.to_be_bytes
());
// headers_len = 0
// No headers bytes
message_empty
.extend_from_slice
(
&
(
payload
.len
()
as
u32
)
.to_be_bytes
());
message_empty
.extend_from_slice
(
payload
);
let
mut
reader
=
&
message_empty
[
..
];
let
mut
decoder
=
ZeroCopyTcpDecoder
::
new
();
let
msg
=
decoder
.read_message
(
&
mut
reader
)
.await
.unwrap
();
assert_eq!
(
msg
.endpoint_path
()
.unwrap
(),
endpoint
);
assert_eq!
(
msg
.payload
()
.as_ref
(),
payload
);
assert_eq!
(
msg
.headers
()
.len
(),
0
);
assert_eq!
(
msg
.headers_bytes
()
.len
(),
0
);
// Test 2: Populated headers with same decoder
let
mut
headers_map
=
std
::
collections
::
HashMap
::
new
();
headers_map
.insert
(
"x-test-header"
.to_string
(),
"test-value"
.to_string
());
let
headers_json
=
serde_json
::
to_vec
(
&
headers_map
)
.unwrap
();
let
mut
message_with_headers
=
Vec
::
new
();
message_with_headers
.extend_from_slice
(
&
(
endpoint
.len
()
as
u16
)
.to_be_bytes
());
message_with_headers
.extend_from_slice
(
endpoint
.as_bytes
());
message_with_headers
.extend_from_slice
(
&
(
headers_json
.len
()
as
u16
)
.to_be_bytes
());
message_with_headers
.extend_from_slice
(
&
headers_json
);
message_with_headers
.extend_from_slice
(
&
(
payload
.len
()
as
u32
)
.to_be_bytes
());
message_with_headers
.extend_from_slice
(
payload
);
let
mut
reader
=
&
message_with_headers
[
..
];
let
msg
=
decoder
.read_message
(
&
mut
reader
)
.await
.unwrap
();
assert_eq!
(
msg
.endpoint_path
()
.unwrap
(),
endpoint
);
assert_eq!
(
msg
.payload
()
.as_ref
(),
payload
);
assert_eq!
(
msg
.headers
()
.len
(),
1
);
assert_eq!
(
msg
.headers
()
.get
(
"x-test-header"
)
.unwrap
(),
"test-value"
);
}
}
lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
View file @
bbb79afb
...
...
@@ -25,6 +25,13 @@ use tracing::Instrument;
/// Default maximum message size for TCP server (32 MB)
const
DEFAULT_MAX_MESSAGE_SIZE
:
usize
=
32
*
1024
*
1024
;
/// Default worker pool size for TCP request handling
const
DEFAULT_WORKER_POOL_SIZE
:
usize
=
1500
;
/// Default work queue size for TCP request handling
/// this is 4X the worker pool size to handle burst traffic
const
DEFAULT_WORK_QUEUE_SIZE
:
usize
=
6000
;
/// Get maximum message size from environment or use default
fn
get_max_message_size
()
->
usize
{
std
::
env
::
var
(
"DYN_TCP_MAX_MESSAGE_SIZE"
)
...
...
@@ -33,6 +40,35 @@ fn get_max_message_size() -> usize {
.unwrap_or
(
DEFAULT_MAX_MESSAGE_SIZE
)
}
/// Get worker pool size from environment or use default
fn
get_worker_pool_size
()
->
usize
{
std
::
env
::
var
(
"DYN_TCP_WORKER_POOL_SIZE"
)
.ok
()
.and_then
(|
s
|
s
.parse
::
<
usize
>
()
.ok
())
.unwrap_or
(
DEFAULT_WORKER_POOL_SIZE
)
}
/// Get work queue size from environment or use default
fn
get_work_queue_size
()
->
usize
{
std
::
env
::
var
(
"DYN_TCP_WORK_QUEUE_SIZE"
)
.ok
()
.and_then
(|
s
|
s
.parse
::
<
usize
>
()
.ok
())
.unwrap_or
(
DEFAULT_WORK_QUEUE_SIZE
)
}
/// Work item for the worker pool
struct
WorkItem
{
service_handler
:
Arc
<
dyn
PushWorkHandler
>
,
payload
:
Bytes
,
headers
:
std
::
collections
::
HashMap
<
String
,
String
>
,
inflight
:
Arc
<
AtomicU64
>
,
notify
:
Arc
<
Notify
>
,
instance_id
:
u64
,
namespace
:
String
,
component_name
:
String
,
endpoint_name
:
String
,
}
/// Shared TCP server that handles multiple endpoints on a single port
pub
struct
SharedTcpServer
{
handlers
:
Arc
<
DashMap
<
String
,
Arc
<
EndpointHandler
>>>
,
...
...
@@ -41,6 +77,8 @@ pub struct SharedTcpServer {
/// The actual bound address (populated after bind_and_start, contains actual port)
actual_addr
:
RwLock
<
Option
<
SocketAddr
>>
,
cancellation_token
:
CancellationToken
,
/// Channel for sending work to the worker pool
work_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
WorkItem
>
,
}
struct
EndpointHandler
{
...
...
@@ -56,6 +94,21 @@ struct EndpointHandler {
impl
SharedTcpServer
{
pub
fn
new
(
bind_addr
:
SocketAddr
,
cancellation_token
:
CancellationToken
)
->
Arc
<
Self
>
{
let
worker_pool_size
=
get_worker_pool_size
();
let
work_queue_size
=
get_work_queue_size
();
tracing
::
info!
(
"Initializing TCP server with dispatcher (concurrency={}, queue={})"
,
worker_pool_size
,
work_queue_size
);
// Create bounded channel for work items
let
(
work_tx
,
work_rx
)
=
tokio
::
sync
::
mpsc
::
channel
(
work_queue_size
);
// Start worker pool
Self
::
start_worker_pool
(
worker_pool_size
,
work_rx
,
cancellation_token
.clone
());
Arc
::
new
(
Self
{
handlers
:
Arc
::
new
(
DashMap
::
new
()),
// address we requested to bind to.
...
...
@@ -63,9 +116,103 @@ impl SharedTcpServer {
// actual address after free port assignment (if DYN_TCP_RPC_PORT is not specified)
actual_addr
:
RwLock
::
new
(
None
),
cancellation_token
,
work_tx
,
})
}
/// Start the worker pool dispatcher that processes requests with bounded concurrency
///
/// Uses a single receiver with a semaphore to bound concurrent execution,
/// avoiding mutex contention that would serialize all workers.
fn
start_worker_pool
(
pool_size
:
usize
,
mut
work_rx
:
tokio
::
sync
::
mpsc
::
Receiver
<
WorkItem
>
,
cancellation_token
:
CancellationToken
,
)
{
let
semaphore
=
Arc
::
new
(
tokio
::
sync
::
Semaphore
::
new
(
pool_size
));
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"TCP worker dispatcher started with concurrency limit {}"
,
pool_size
);
loop
{
tokio
::
select!
{
biased
;
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
trace!
(
"TCP worker dispatcher shutting down: cancellation requested"
);
break
;
}
msg
=
work_rx
.recv
()
=>
{
let
Some
(
work_item
)
=
msg
else
{
tracing
::
trace!
(
"TCP worker dispatcher shutting down: channel closed"
);
break
;
};
// Acquire permit before spawning (bounds concurrency)
let
permit
=
match
semaphore
.clone
()
.acquire_owned
()
.await
{
Ok
(
p
)
=>
p
,
Err
(
_
)
=>
{
tracing
::
trace!
(
"TCP worker dispatcher: semaphore closed"
);
break
;
}
};
// Spawn task with owned permit (dropped when task completes)
tokio
::
spawn
(
async
move
{
Self
::
handle_work_item
(
work_item
)
.await
;
drop
(
permit
);
});
}
}
}
tracing
::
trace!
(
"TCP worker dispatcher exited"
);
});
tracing
::
info!
(
"Started TCP worker dispatcher with concurrency limit {}"
,
pool_size
);
}
/// Handle a single work item
async
fn
handle_work_item
(
work_item
:
WorkItem
)
{
tracing
::
trace!
(
instance_id
=
work_item
.instance_id
,
"TCP worker processing request"
);
// Create span with trace context from headers
let
span
=
crate
::
logging
::
make_handle_payload_span_from_tcp_headers
(
&
work_item
.headers
,
&
work_item
.component_name
,
&
work_item
.endpoint_name
,
&
work_item
.namespace
,
work_item
.instance_id
,
);
let
result
=
work_item
.service_handler
.handle_payload
(
work_item
.payload
)
.instrument
(
span
)
.await
;
if
let
Err
(
e
)
=
result
{
tracing
::
warn!
(
instance_id
=
work_item
.instance_id
,
error
=
%
e
,
"TCP worker failed to handle request"
);
}
work_item
.inflight
.fetch_sub
(
1
,
Ordering
::
SeqCst
);
work_item
.notify
.notify_one
();
}
/// Bind the server and start accepting connections.
///
/// This method binds to the configured address first, then starts the accept loop.
...
...
@@ -116,8 +263,9 @@ impl SharedTcpServer {
tracing
::
trace!
(
"Accepted TCP connection from {}"
,
peer_addr
);
let
handlers
=
self
.handlers
.clone
();
let
work_tx
=
self
.work_tx
.clone
();
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
Self
::
handle_connection
(
stream
,
handlers
)
.await
{
if
let
Err
(
e
)
=
Self
::
handle_connection
(
stream
,
handlers
,
work_tx
)
.await
{
tracing
::
error!
(
"TCP connection error: {}"
,
e
);
}
});
...
...
@@ -219,6 +367,7 @@ impl SharedTcpServer {
async
fn
handle_connection
(
stream
:
TcpStream
,
handlers
:
Arc
<
DashMap
<
String
,
Arc
<
EndpointHandler
>>>
,
work_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
WorkItem
>
,
)
->
Result
<
()
>
{
use
crate
::
pipeline
::
network
::
codec
::{
TcpRequestMessage
,
TcpResponseMessage
};
...
...
@@ -232,7 +381,7 @@ impl SharedTcpServer {
let
write_task
=
tokio
::
spawn
(
Self
::
write_loop
(
write_half
,
response_rx
));
// Run read task in current context
let
read_result
=
Self
::
read_loop
(
read_half
,
handlers
,
response_tx
)
.await
;
let
read_result
=
Self
::
read_loop
(
read_half
,
handlers
,
response_tx
,
work_tx
)
.await
;
// Write task will end when response_tx is dropped
write_task
.await
??
;
...
...
@@ -244,82 +393,40 @@ impl SharedTcpServer {
mut
read_half
:
tokio
::
io
::
ReadHalf
<
TcpStream
>
,
handlers
:
Arc
<
DashMap
<
String
,
Arc
<
EndpointHandler
>>>
,
response_tx
:
tokio
::
sync
::
mpsc
::
UnboundedSender
<
Bytes
>
,
work_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
WorkItem
>
,
)
->
Result
<
()
>
{
use
crate
::
pipeline
::
network
::
codec
::{
TcpRequestMessage
,
TcpResponseMessage
};
use
crate
::
pipeline
::
network
::
codec
::{
TcpResponseMessage
,
ZeroCopyTcpDecoder
};
// Create zero-copy decoder with optimized buffer size
let
mut
decoder
=
ZeroCopyTcpDecoder
::
new
();
loop
{
// Read endpoint path length (2 bytes)
let
mut
path_len_buf
=
[
0u8
;
2
];
match
read_half
.read_exact
(
&
mut
path_len_buf
)
.await
{
Ok
(
_
)
=>
{}
// Read one complete message with ZERO copies!
let
request_msg
=
match
decoder
.read_message
(
&
mut
read_half
)
.await
{
Ok
(
msg
)
=>
msg
,
Err
(
e
)
if
e
.kind
()
==
std
::
io
::
ErrorKind
::
UnexpectedEof
=>
{
tracing
::
trace!
(
"Connection closed by peer"
);
break
;
}
Err
(
e
)
=>
{
return
Err
(
e
.into
());
}
}
let
path_len
=
u16
::
from_be_bytes
(
path_len_buf
)
as
usize
;
// Read endpoint path
let
mut
path_buf
=
vec!
[
0u8
;
path_len
];
read_half
.read_exact
(
&
mut
path_buf
)
.await
?
;
// Read headers length (2 bytes)
let
mut
headers_len_buf
=
[
0u8
;
2
];
read_half
.read_exact
(
&
mut
headers_len_buf
)
.await
?
;
let
headers_len
=
u16
::
from_be_bytes
(
headers_len_buf
)
as
usize
;
// Read headers
let
mut
headers_buf
=
vec!
[
0u8
;
headers_len
];
read_half
.read_exact
(
&
mut
headers_buf
)
.await
?
;
// Read payload length (4 bytes)
let
mut
len_buf
=
[
0u8
;
4
];
read_half
.read_exact
(
&
mut
len_buf
)
.await
?
;
let
payload_len
=
u32
::
from_be_bytes
(
len_buf
)
as
usize
;
// Sanity check - enforce maximum message size
let
max_message_size
=
get_max_message_size
();
if
payload_len
>
max_message_size
{
tracing
::
warn!
(
"Request too large: {} bytes (max: {} bytes), closing connection"
,
payload_len
,
max_message_size
);
tracing
::
warn!
(
"Failed to read TCP request: {}"
,
e
);
// Send error response
let
error_response
=
TcpResponseMessage
::
new
(
Bytes
::
from
_static
(
b
"Request too large"
));
TcpResponseMessage
::
new
(
Bytes
::
from
(
format!
(
"Read error: {}"
,
e
)
));
if
let
Ok
(
encoded
)
=
error_response
.encode
()
{
let
_
=
response_tx
.send
(
encoded
);
}
break
;
return
Err
(
e
.into
())
;
}
};
// Read request payload
let
mut
payload_buf
=
vec!
[
0u8
;
payload_len
];
read_half
.read_exact
(
&
mut
payload_buf
)
.await
?
;
// Reconstruct the full message buffer for decoding using BytesMut
let
mut
full_msg
=
BytesMut
::
with_capacity
(
2
+
path_len
+
2
+
headers_len
+
4
+
payload_len
);
full_msg
.extend_from_slice
(
&
path_len_buf
);
full_msg
.extend_from_slice
(
&
path_buf
);
full_msg
.extend_from_slice
(
&
headers_len_buf
);
full_msg
.extend_from_slice
(
&
headers_buf
);
full_msg
.extend_from_slice
(
&
len_buf
);
full_msg
.extend_from_slice
(
&
payload_buf
);
// Decode using codec (zero-copy conversion)
let
full_msg_bytes
=
full_msg
.freeze
();
let
request_msg
=
match
TcpRequestMessage
::
decode
(
&
full_msg_bytes
)
{
Ok
(
msg
)
=>
msg
,
// Get endpoint path (zero-copy string slice)
let
endpoint_path
=
match
request_msg
.endpoint_path
()
{
Ok
(
path
)
=>
path
,
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to decode TCP request: {}"
,
e
);
// Send error response
tracing
::
warn!
(
"Invalid UTF-8 in endpoint path: {}"
,
e
);
let
error_response
=
TcpResponseMessage
::
new
(
Bytes
::
from
(
format!
(
"Decode error: {}"
,
e
)
));
TcpResponseMessage
::
new
(
Bytes
::
from
_static
(
b
"Invalid endpoint path"
));
if
let
Ok
(
encoded
)
=
error_response
.encode
()
{
let
_
=
response_tx
.send
(
encoded
);
}
...
...
@@ -327,18 +434,27 @@ impl SharedTcpServer {
}
};
let
endpoint_path
=
request_msg
.endpoint_path
;
let
headers
=
request_msg
.headers
;
let
payload
=
request_msg
.payload
;
// Get headers (parsed from message)
let
headers
=
request_msg
.headers
();
// Get payload (zero-copy Bytes - just Arc clone!)
let
payload
=
request_msg
.payload
();
tracing
::
trace!
(
endpoint
=
endpoint_path
,
payload_len
=
payload
.len
(),
total_size
=
request_msg
.total_size
(),
"Received TCP request"
);
// Look up handler (lock-free read with DashMap)
let
handler
=
handlers
.get
(
&
endpoint_path
)
.map
(|
h
|
h
.clone
());
let
handler
=
handlers
.get
(
endpoint_path
)
.map
(|
h
|
h
.clone
());
let
handler
=
match
handler
{
Some
(
h
)
=>
h
,
None
=>
{
tracing
::
warn!
(
"No handler found for endpoint: {}"
,
endpoint_path
);
// Send error response
using codec
// Send error response
let
error_response
=
TcpResponseMessage
::
new
(
Bytes
::
from
(
format!
(
"Unknown endpoint: {}"
,
endpoint_path
...
...
@@ -352,54 +468,67 @@ impl SharedTcpServer {
handler
.inflight
.fetch_add
(
1
,
Ordering
::
SeqCst
);
// Send acknowledgment immediately using codec (non-blocking, zero-copy)
// Build work item
// NOTE: payload is Bytes (Arc-counted), so cloning is extremely cheap
let
work_item
=
WorkItem
{
service_handler
:
handler
.service_handler
.clone
(),
payload
,
headers
,
inflight
:
handler
.inflight
.clone
(),
notify
:
handler
.notify
.clone
(),
instance_id
:
handler
.instance_id
,
namespace
:
handler
.namespace
.clone
(),
component_name
:
handler
.component_name
.clone
(),
endpoint_name
:
handler
.endpoint_name
.clone
(),
};
// Send to worker pool with backpressure - BEFORE sending ACK
match
work_tx
.send
(
work_item
)
.await
{
Ok
(
_
)
=>
{
// Send acknowledgment ONLY after successful queuing
let
ack_response
=
TcpResponseMessage
::
empty
();
if
let
Ok
(
encoded_ack
)
=
ack_response
.encode
()
{
// Send to write task without blocking reads
if
response_tx
.send
(
encoded_ack
)
.is_err
()
{
if
let
Ok
(
encoded_ack
)
=
ack_response
.encode
()
&&
response_tx
.send
(
encoded_ack
)
.is_err
()
{
tracing
::
debug!
(
"Write task closed, ending read loop"
);
// Clean up inflight counter since work was queued but ACK failed
handler
.inflight
.fetch_sub
(
1
,
Ordering
::
SeqCst
);
handler
.notify
.notify_one
();
break
;
}
}
// Process request asynchronously
let
service_handler
=
handler
.service_handler
.clone
();
let
inflight
=
handler
.inflight
.clone
();
let
notify
=
handler
.notify
.clone
();
let
instance_id
=
handler
.instance_id
;
let
namespace
=
handler
.namespace
.clone
();
let
component_name
=
handler
.component_name
.clone
();
let
endpoint_name
=
handler
.endpoint_name
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
instance_id
,
"handling TCP request"
);
// Create span with trace context from headers
let
span
=
crate
::
logging
::
make_handle_payload_span_from_tcp_headers
(
&
headers
,
&
component_name
,
&
endpoint_name
,
&
namespace
,
instance_id
,
tracing
::
trace!
(
endpoint
=
handler
.endpoint_name
.as_str
(),
instance_id
=
handler
.instance_id
,
"Request queued and acknowledged"
);
}
Err
(
e
)
=>
{
tracing
::
warn!
(
endpoint
=
handler
.endpoint_name
.as_str
(),
instance_id
=
handler
.instance_id
,
error
=
%
e
,
"Failed to queue work to worker pool, sending error response"
);
let
result
=
service_handler
.handle_payload
(
payload
)
.instrument
(
span
)
.await
;
// Send error response to client instead of ACK
let
error_response
=
TcpResponseMessage
::
new
(
Bytes
::
from
(
format!
(
"Server overloaded: {}"
,
e
)));
if
let
Ok
(
encoded
)
=
error_response
.encode
()
{
let
_
=
response_tx
.send
(
encoded
);
}
match
result
{
Ok
(
_
)
=>
{
tracing
::
trace!
(
instance_id
,
"TCP request handled successfully"
);
// Clean up inflight counter
handler
.inflight
.fetch_sub
(
1
,
Ordering
::
SeqCst
);
handler
.notify
.notify_one
();
// If channel is closed, break the loop
if
matches!
(
e
,
tokio
::
sync
::
mpsc
::
error
::
SendError
(
_
))
{
tracing
::
error!
(
"Worker pool channel closed, shutting down read loop"
);
break
;
}
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to handle TCP request: {}"
,
e
);
}
}
inflight
.fetch_sub
(
1
,
Ordering
::
SeqCst
);
notify
.notify_one
();
});
}
Ok
(())
...
...
@@ -692,4 +821,138 @@ mod tests {
tracing
::
info!
(
"Test passed: unregister_endpoint properly waited for inflight TCP request"
);
}
///////////////////// TESTS FOR CONCURRENCY BOUNDING /////////////////////
/// Mock handler that tracks concurrent execution count
struct
ConcurrencyTrackingHandler
{
/// Current number of concurrent requests being processed
concurrent_count
:
Arc
<
AtomicU64
>
,
/// Maximum concurrent count observed
max_concurrent
:
Arc
<
AtomicU64
>
,
/// Duration to simulate request processing
processing_duration
:
Duration
,
/// Notifies when a request completes
completed
:
Arc
<
Notify
>
,
}
impl
ConcurrencyTrackingHandler
{
fn
new
(
processing_duration
:
Duration
)
->
Self
{
Self
{
concurrent_count
:
Arc
::
new
(
AtomicU64
::
new
(
0
)),
max_concurrent
:
Arc
::
new
(
AtomicU64
::
new
(
0
)),
processing_duration
,
completed
:
Arc
::
new
(
Notify
::
new
()),
}
}
}
#[async_trait]
impl
PushWorkHandler
for
ConcurrencyTrackingHandler
{
async
fn
handle_payload
(
&
self
,
_
payload
:
Bytes
)
->
Result
<
(),
PipelineError
>
{
// Increment concurrent count
let
current
=
self
.concurrent_count
.fetch_add
(
1
,
Ordering
::
SeqCst
)
+
1
;
// Update max if this is higher
self
.max_concurrent
.fetch_max
(
current
,
Ordering
::
SeqCst
);
// Simulate work
tokio
::
time
::
sleep
(
self
.processing_duration
)
.await
;
// Decrement concurrent count
self
.concurrent_count
.fetch_sub
(
1
,
Ordering
::
SeqCst
);
self
.completed
.notify_one
();
Ok
(())
}
fn
add_metrics
(
&
self
,
_
endpoint
:
&
crate
::
component
::
Endpoint
,
_
metrics_labels
:
Option
<&
[(
&
str
,
&
str
)]
>
,
)
->
Result
<
()
>
{
Ok
(())
}
}
#[tokio::test]
async
fn
test_worker_pool_bounds_concurrency
()
{
let
_
=
tracing_subscriber
::
fmt
()
.with_test_writer
()
.with_max_level
(
tracing
::
Level
::
DEBUG
)
.try_init
();
// Use a small pool size for testing
let
pool_size
=
3
;
let
total_requests
=
10
;
// Create bounded channel and dispatcher directly
let
(
work_tx
,
work_rx
)
=
tokio
::
sync
::
mpsc
::
channel
::
<
WorkItem
>
(
total_requests
);
let
cancellation_token
=
CancellationToken
::
new
();
// Start worker pool with small concurrency limit
SharedTcpServer
::
start_worker_pool
(
pool_size
,
work_rx
,
cancellation_token
.clone
());
// Create tracking handler
let
handler
=
Arc
::
new
(
ConcurrencyTrackingHandler
::
new
(
Duration
::
from_millis
(
50
)));
// Create dummy inflight/notify for work items
let
inflight
=
Arc
::
new
(
AtomicU64
::
new
(
0
));
let
notify
=
Arc
::
new
(
Notify
::
new
());
// Send more work items than pool size
for
i
in
0
..
total_requests
{
inflight
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
work_item
=
WorkItem
{
service_handler
:
handler
.clone
()
as
Arc
<
dyn
PushWorkHandler
>
,
payload
:
Bytes
::
from
(
format!
(
"request {}"
,
i
)),
headers
:
std
::
collections
::
HashMap
::
new
(),
inflight
:
inflight
.clone
(),
notify
:
notify
.clone
(),
instance_id
:
1
,
namespace
:
"test"
.to_string
(),
component_name
:
"test"
.to_string
(),
endpoint_name
:
"test"
.to_string
(),
};
work_tx
.send
(
work_item
)
.await
.expect
(
"send should succeed"
);
}
// Wait for all requests to complete
let
timeout
=
tokio
::
time
::
timeout
(
Duration
::
from_secs
(
5
),
async
{
while
inflight
.load
(
Ordering
::
SeqCst
)
>
0
{
notify
.notified
()
.await
;
}
})
.await
;
assert
!
(
timeout
.is_ok
(),
"All requests should complete within timeout"
);
// Verify concurrency was bounded
let
max_observed
=
handler
.max_concurrent
.load
(
Ordering
::
SeqCst
);
assert
!
(
max_observed
<=
pool_size
as
u64
,
"Max concurrent ({}) should not exceed pool size ({})"
,
max_observed
,
pool_size
);
// Verify all requests completed
assert_eq!
(
inflight
.load
(
Ordering
::
SeqCst
),
0
,
"All requests should have completed"
);
tracing
::
info!
(
"Test passed: max concurrent {} <= pool size {}"
,
max_observed
,
pool_size
);
// Cleanup
cancellation_token
.cancel
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment