Unverified Commit e47457e4 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

test: unit test tcp/client.rs handle_writer [1/n] (#5055)

parent a170b31f
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
pub mod client; pub mod client;
pub mod server; pub mod server;
pub mod test_utils;
use super::ControlMessage; use super::ControlMessage;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
......
...@@ -324,3 +324,399 @@ async fn handle_writer( ...@@ -324,3 +324,399 @@ async fn handle_writer(
drop(alive_rx); drop(alive_rx);
Ok(framed_writer) Ok(framed_writer)
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::pipeline::context::Controller;
use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
use bytes::Bytes;
use futures::StreamExt;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tokio_util::codec::FramedRead;
struct WriterHarness {
server: tokio::net::TcpStream,
framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
bytes_tx: mpsc::Sender<TwoPartMessage>,
bytes_rx: mpsc::Receiver<TwoPartMessage>,
alive_tx: oneshot::Sender<()>,
alive_rx: oneshot::Receiver<()>,
controller: Arc<Controller>,
}
/// Creates a reusable writer harness with paired TCP streams and test channels.
async fn writer_harness() -> WriterHarness {
let (client, server) = create_tcp_pair().await;
let (_, write_half) = tokio::io::split(client);
let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let (bytes_tx, bytes_rx) = mpsc::channel(64);
let (alive_tx, alive_rx) = oneshot::channel::<()>();
let controller = Arc::new(Controller::default());
WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_tx,
alive_rx,
controller,
}
}
async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
reader
.next()
.await
.expect("expected message")
.expect("failed to decode message")
}
fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
let (header, data) = msg.optional_parts();
assert!(header.is_none(), "data-only message should not have header");
assert_eq!(
data.expect("data payload missing").as_ref(),
expected,
"data payload should match"
);
}
fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
let (header, data) = msg.optional_parts();
assert!(data.is_none(), "header-only message should not carry data");
assert_eq!(
header.expect("header missing").as_ref(),
expected,
"header payload should match"
);
}
fn assert_header_and_data_message(
msg: TwoPartMessage,
expected_header: &[u8],
expected_data: &[u8],
) {
let (header, data) = msg.optional_parts();
assert_eq!(
header.expect("header missing").as_ref(),
expected_header,
"header payload should match"
);
assert_eq!(
data.expect("data missing").as_ref(),
expected_data,
"data payload should match"
);
}
fn assert_sentinel_message(msg: TwoPartMessage) {
let (header, data) = msg.optional_parts();
assert!(data.is_none(), "sentinel should not include a data section");
let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert_eq!(
header.expect("sentinel header missing").as_ref(),
expected_sentinel.as_slice(),
"sentinel header should match serialized ControlMessage::Sentinel"
);
}
/// Test that handle_writer forwards messages from the channel to the framed writer
#[tokio::test]
async fn test_handle_writer_forwards_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Send test messages
let test_msg = TwoPartMessage::from_data(Bytes::from("test data"));
bytes_tx.send(test_msg).await.unwrap();
// Close the sender to trigger normal termination
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
// Decode from server side to verify data and sentinel were sent
let mut reader = FramedRead::new(server, TwoPartCodec::default());
let msg = recv_msg(&mut reader).await;
assert_data_only_message(msg, b"test data");
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
/// Test that handle_writer sends sentinel on normal channel closure
#[tokio::test]
async fn test_handle_writer_sends_sentinel_on_normal_closure() {
let WriterHarness {
mut server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Close the sender immediately to trigger normal termination
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
// Read from server side to verify sentinel was sent
let mut buffer = vec![0u8; 1024];
let n = server.read(&mut buffer).await.unwrap();
// Buffer should contain the sentinel message
assert!(n > 0, "Expected sentinel to be written to the TCP stream");
// Verify it contains the sentinel message by checking for the JSON
let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert!(
buffer[..n]
.windows(sentinel_json.len())
.any(|w| w == sentinel_json.as_slice()),
"Buffer should contain sentinel message. Buffer: {:?}",
String::from_utf8_lossy(&buffer[..n])
);
}
/// Test that handle_writer does NOT send sentinel when context is killed
#[tokio::test]
async fn test_handle_writer_no_sentinel_on_context_killed() {
let WriterHarness {
mut server,
framed_writer,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Kill the context
controller.kill();
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
// Drop the writer to close the connection, then try to read. Otherwise,
// the test will hang on `server.read()`
drop(result);
// Read from server side - should get no sentinel
let mut buffer = vec![0u8; 1024];
let n = server.read(&mut buffer).await.unwrap();
// Buffer should be empty (no sentinel sent)
let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert!(
n == 0
|| !buffer[..n]
.windows(sentinel_json.len())
.any(|w| w == sentinel_json.as_slice()),
"Buffer should NOT contain sentinel message when context is killed"
);
}
/// Test that handle_writer does NOT send sentinel when context is stopped
#[tokio::test]
async fn test_handle_writer_no_sentinel_on_context_stopped() {
let WriterHarness {
mut server,
framed_writer,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Stop the context
controller.stop();
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
// Drop the writer to close the connection, then try to read. Otherwise,
// the test will hang on `server.read()`
drop(result);
// Read from server side - should get no sentinel
let mut buffer = vec![0u8; 1024];
let n = server.read(&mut buffer).await.unwrap();
// Buffer should be empty (no sentinel sent)
let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert!(
n == 0
|| !buffer[..n]
.windows(sentinel_json.len())
.any(|w| w == sentinel_json.as_slice()),
"Buffer should NOT contain sentinel message when context is stopped"
);
}
/// Test that handle_writer handles multiple messages correctly
#[tokio::test]
async fn test_handle_writer_multiple_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Send multiple messages
for i in 0..5 {
let test_msg = TwoPartMessage::from_data(Bytes::from(format!("message {}", i)));
bytes_tx.send(test_msg).await.unwrap();
}
// Close the sender to trigger normal termination
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
// Decode from server side to verify all messages plus sentinel
let mut reader = FramedRead::new(server, TwoPartCodec::default());
for i in 0..5 {
let msg = recv_msg(&mut reader).await;
assert_data_only_message(msg, format!("message {}", i).as_bytes());
}
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
/// Test that alive_rx is dropped after handle_writer completes
#[tokio::test]
async fn test_handle_writer_drops_alive_rx() {
let WriterHarness {
framed_writer,
bytes_tx,
bytes_rx,
alive_tx,
alive_rx,
controller,
..
} = writer_harness().await;
// Close the sender to trigger normal termination
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
// alive_tx should now be closed because alive_rx was dropped
assert!(alive_tx.is_closed());
}
/// Test handle_writer with header-only messages (control messages)
#[tokio::test]
async fn test_handle_writer_header_only_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Send a header-only message
let header_msg = TwoPartMessage::from_header(Bytes::from("header content"));
bytes_tx.send(header_msg).await.unwrap();
// Close the sender
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut reader = FramedRead::new(server, TwoPartCodec::default());
let header_msg = recv_msg(&mut reader).await;
assert_header_only_message(header_msg, b"header content");
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
/// Test handle_writer with mixed header and data messages
#[tokio::test]
async fn test_handle_writer_mixed_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
// Send mixed messages
bytes_tx
.send(TwoPartMessage::from_header(Bytes::from("header1")))
.await
.unwrap();
bytes_tx
.send(TwoPartMessage::from_data(Bytes::from("data1")))
.await
.unwrap();
bytes_tx
.send(TwoPartMessage::from_parts(
Bytes::from("header2"),
Bytes::from("data2"),
))
.await
.unwrap();
// Close the sender
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut reader = FramedRead::new(server, TwoPartCodec::default());
let first = recv_msg(&mut reader).await;
assert_header_only_message(first, b"header1");
let second = recv_msg(&mut reader).await;
assert_data_only_message(second, b"data1");
let third = recv_msg(&mut reader).await;
assert_header_and_data_message(third, b"header2", b"data2");
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Test utilities shared across TCP transport tests.
use tokio::net::TcpListener;
/// Creates a connected TCP pair for testing.
///
/// Returns a tuple of (client, server) TcpStream instances that are connected to each other.
/// This is useful for testing functions that operate on TCP streams without needing
/// actual network communication.
pub async fn create_tcp_pair() -> (tokio::net::TcpStream, tokio::net::TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client = tokio::net::TcpStream::connect(addr).await.unwrap();
let (server, _) = listener.accept().await.unwrap();
(client, server)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment