// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! Codec Module //! //! Codec map structure into blobs of bytes and streams of bytes. //! //! In this module, we define three primary codec used to issue single, two-part or multi-part messages, //! on a byte stream. use bytes::Bytes; use tokio_util::{ bytes::{Buf, BufMut, BytesMut}, codec::{Decoder, Encoder}, }; mod two_part; pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType}; /// TCP request plane protocol message with endpoint routing /// /// Wire format: /// - endpoint_path_len: u16 (big-endian) /// - endpoint_path: UTF-8 string /// - payload_len: u32 (big-endian) /// - payload: bytes #[derive(Debug, Clone, PartialEq, Eq)] pub struct TcpRequestMessage { pub endpoint_path: String, pub payload: Bytes, } impl TcpRequestMessage { pub fn new(endpoint_path: String, payload: Bytes) -> Self { Self { endpoint_path, payload, } } /// Encode message to bytes pub fn encode(&self) -> Result { let endpoint_bytes = self.endpoint_path.as_bytes(); let endpoint_len = endpoint_bytes.len(); if endpoint_len > u16::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Endpoint path too long: {} bytes", endpoint_len), )); } if self.payload.len() > u32::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Payload too large: {} bytes", self.payload.len()), )); } // Use BytesMut for efficient buffer building let mut buf = BytesMut::with_capacity(2 + endpoint_len + 4 + self.payload.len()); // Write endpoint path length (2 bytes) buf.put_u16(endpoint_len as u16); // Write endpoint path buf.put_slice(endpoint_bytes); // Write payload length (4 bytes) buf.put_u32(self.payload.len() as u32); // Write payload buf.put_slice(&self.payload); // Zero-copy conversion to Bytes Ok(buf.freeze()) } /// Decode message from bytes (for backward compatibility, zero-copy when possible) pub fn decode(bytes: &Bytes) -> Result { if bytes.len() < 2 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "Not enough bytes for endpoint path length", )); } // Read endpoint path length (2 bytes) let endpoint_len = u16::from_be_bytes([bytes[0], bytes[1]]) as usize; let mut offset = 2; if bytes.len() < offset + endpoint_len { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "Not enough bytes for endpoint path", )); } // Read endpoint path (requires copy for UTF-8 validation) let endpoint_path = String::from_utf8(bytes[offset..offset + endpoint_len].to_vec()) .map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, format!("Invalid UTF-8: {}", e), ) })?; offset += endpoint_len; if bytes.len() < offset + 4 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "Not enough bytes for payload length", )); } // Read payload length (4 bytes) let payload_len = u32::from_be_bytes([ bytes[offset], bytes[offset + 1], bytes[offset + 2], bytes[offset + 3], ]) as usize; offset += 4; if bytes.len() < offset + payload_len { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, format!( "Not enough bytes for payload: expected {}, got {}", payload_len, bytes.len() - offset ), )); } // Read payload (zero-copy slice) let payload = bytes.slice(offset..offset + payload_len); Ok(Self { endpoint_path, payload, }) } } /// Codec for encoding/decoding TcpRequestMessage /// Supports max_message_size enforcement #[derive(Clone, Default)] pub struct TcpRequestCodec { max_message_size: Option, } impl TcpRequestCodec { pub fn new(max_message_size: Option) -> Self { Self { max_message_size } } } impl Decoder for TcpRequestCodec { type Item = TcpRequestMessage; type Error = std::io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { // Need at least 2 bytes for endpoint_path_len if src.len() < 2 { return Ok(None); } // Peek at endpoint path length without consuming let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize; let header_size = 2 + endpoint_len + 4; // path_len + path + payload_len if src.len() < header_size { return Ok(None); } // Peek at payload length let payload_len_offset = 2 + endpoint_len; let payload_len = u32::from_be_bytes([ src[payload_len_offset], src[payload_len_offset + 1], src[payload_len_offset + 2], src[payload_len_offset + 3], ]) as usize; let total_len = header_size + payload_len; // Check max message size if let Some(max_size) = self.max_message_size && total_len > max_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "Request too large: {} bytes (max: {} bytes)", total_len, max_size ), )); } // Check if we have the full message if src.len() < total_len { return Ok(None); } // We have a complete message, advance past length prefix src.advance(2); // Read endpoint path let endpoint_bytes = src.split_to(endpoint_len); let endpoint_path = String::from_utf8(endpoint_bytes.to_vec()).map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, format!("Invalid UTF-8 in endpoint path: {}", e), ) })?; // Advance past payload length src.advance(4); // Read payload let payload = src.split_to(payload_len).freeze(); Ok(Some(TcpRequestMessage { endpoint_path, payload, })) } } impl Encoder for TcpRequestCodec { type Error = std::io::Error; fn encode(&mut self, item: TcpRequestMessage, dst: &mut BytesMut) -> Result<(), Self::Error> { let endpoint_bytes = item.endpoint_path.as_bytes(); let endpoint_len = endpoint_bytes.len(); if endpoint_len > u16::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Endpoint path too long: {} bytes", endpoint_len), )); } if item.payload.len() > u32::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Payload too large: {} bytes", item.payload.len()), )); } let total_len = 2 + endpoint_len + 4 + item.payload.len(); // Check max message size if let Some(max_size) = self.max_message_size && total_len > max_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!( "Request too large: {} bytes (max: {} bytes)", total_len, max_size ), )); } // Reserve space dst.reserve(total_len); // Write endpoint path length dst.put_u16(endpoint_len as u16); // Write endpoint path dst.put_slice(endpoint_bytes); // Write payload length dst.put_u32(item.payload.len() as u32); // Write payload dst.put_slice(&item.payload); Ok(()) } } /// TCP response message (acknowledgment or error) /// /// Wire format: /// - length: u32 (big-endian) /// - data: bytes #[derive(Debug, Clone, PartialEq, Eq)] pub struct TcpResponseMessage { pub data: Bytes, } impl TcpResponseMessage { pub fn new(data: Bytes) -> Self { Self { data } } pub fn empty() -> Self { Self { data: Bytes::new() } } /// Encode response to bytes (for backward compatibility) pub fn encode(&self) -> Result { if self.data.len() > u32::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Response too large: {} bytes", self.data.len()), )); } // Use BytesMut for efficient buffer building let mut buf = BytesMut::with_capacity(4 + self.data.len()); // Write length (4 bytes) buf.put_u32(self.data.len() as u32); // Write data buf.put_slice(&self.data); // Zero-copy conversion to Bytes Ok(buf.freeze()) } /// Decode response from bytes (for backward compatibility, zero-copy when possible) pub fn decode(bytes: &Bytes) -> Result { if bytes.len() < 4 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "Not enough bytes for response length", )); } // Read length (4 bytes) let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; if bytes.len() < 4 + len { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, format!( "Not enough bytes for response: expected {}, got {}", len, bytes.len() - 4 ), )); } // Read data (zero-copy slice) let data = bytes.slice(4..4 + len); Ok(Self { data }) } } /// Codec for encoding/decoding TcpResponseMessage /// Supports max_message_size enforcement #[derive(Clone, Default)] pub struct TcpResponseCodec { max_message_size: Option, } impl TcpResponseCodec { pub fn new(max_message_size: Option) -> Self { Self { max_message_size } } } impl Decoder for TcpResponseCodec { type Item = TcpResponseMessage; type Error = std::io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { // Need at least 4 bytes for length if src.len() < 4 { return Ok(None); } // Peek at message length without consuming let data_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize; let total_len = 4 + data_len; // Check max message size if let Some(max_size) = self.max_message_size && total_len > max_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "Response too large: {} bytes (max: {} bytes)", total_len, max_size ), )); } // Check if we have the full message if src.len() < total_len { return Ok(None); } // Advance past the length prefix src.advance(4); // Read data let data = src.split_to(data_len).freeze(); Ok(Some(TcpResponseMessage { data })) } } impl Encoder for TcpResponseCodec { type Error = std::io::Error; fn encode(&mut self, item: TcpResponseMessage, dst: &mut BytesMut) -> Result<(), Self::Error> { if item.data.len() > u32::MAX as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Response too large: {} bytes", item.data.len()), )); } let total_len = 4 + item.data.len(); // Check max message size if let Some(max_size) = self.max_message_size && total_len > max_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, format!( "Response too large: {} bytes (max: {} bytes)", total_len, max_size ), )); } // Reserve space dst.reserve(total_len); // Write length dst.put_u32(item.data.len() as u32); // Write data dst.put_slice(&item.data); Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_tcp_request_encode_decode() { let msg = TcpRequestMessage::new( "test.endpoint".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]), ); let encoded = msg.encode().unwrap(); let decoded = TcpRequestMessage::decode(&encoded).unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_request_empty_payload() { let msg = TcpRequestMessage::new("test".to_string(), Bytes::new()); let encoded = msg.encode().unwrap(); let decoded = TcpRequestMessage::decode(&encoded).unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_request_large_payload() { let payload = Bytes::from(vec![42u8; 1024 * 1024]); // 1MB let msg = TcpRequestMessage::new("large".to_string(), payload); let encoded = msg.encode().unwrap(); let decoded = TcpRequestMessage::decode(&encoded).unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_request_decode_truncated() { let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5])); let encoded = msg.encode().unwrap(); // Truncate the encoded message let truncated = encoded.slice(..encoded.len() - 2); let result = TcpRequestMessage::decode(&truncated); assert!(result.is_err()); } #[test] fn test_tcp_response_encode_decode() { let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5])); let encoded = msg.encode().unwrap(); let decoded = TcpResponseMessage::decode(&encoded).unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_response_empty() { let msg = TcpResponseMessage::empty(); let encoded = msg.encode().unwrap(); let decoded = TcpResponseMessage::decode(&encoded).unwrap(); assert_eq!(decoded, msg); assert_eq!(decoded.data.len(), 0); } #[test] fn test_tcp_response_decode_truncated() { let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5])); let encoded = msg.encode().unwrap(); // Truncate the encoded message let truncated = encoded.slice(..3); let result = TcpResponseMessage::decode(&truncated); assert!(result.is_err()); } #[test] fn test_tcp_request_unicode_endpoint() { let msg = TcpRequestMessage::new("тест.端点".to_string(), Bytes::from(vec![1, 2, 3])); let encoded = msg.encode().unwrap(); let decoded = TcpRequestMessage::decode(&encoded).unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_request_codec() { use tokio_util::codec::{Decoder, Encoder}; let msg = TcpRequestMessage::new( "test.endpoint".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]), ); let mut codec = TcpRequestCodec::new(None); let mut buf = BytesMut::new(); // Encode codec.encode(msg.clone(), &mut buf).unwrap(); // Decode let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_request_codec_partial() { use tokio_util::codec::Decoder; let msg = TcpRequestMessage::new( "test.endpoint".to_string(), Bytes::from(vec![1, 2, 3, 4, 5]), ); let encoded = msg.encode().unwrap(); let mut codec = TcpRequestCodec::new(None); // Feed partial data let mut buf = BytesMut::from(&encoded[..5]); assert!(codec.decode(&mut buf).unwrap().is_none()); // Feed rest of data buf.extend_from_slice(&encoded[5..]); let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_request_codec_max_size() { use tokio_util::codec::Encoder; let msg = TcpRequestMessage::new("test".to_string(), Bytes::from(vec![1, 2, 3, 4, 5])); let mut codec = TcpRequestCodec::new(Some(10)); // Too small let mut buf = BytesMut::new(); let result = codec.encode(msg, &mut buf); assert!(result.is_err()); } #[test] fn test_tcp_response_codec() { use tokio_util::codec::{Decoder, Encoder}; let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5])); let mut codec = TcpResponseCodec::new(None); let mut buf = BytesMut::new(); // Encode codec.encode(msg.clone(), &mut buf).unwrap(); // Decode let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_response_codec_partial() { use tokio_util::codec::Decoder; let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5])); let encoded = msg.encode().unwrap(); let mut codec = TcpResponseCodec::new(None); // Feed partial data let mut buf = BytesMut::from(&encoded[..3]); assert!(codec.decode(&mut buf).unwrap().is_none()); // Feed rest of data buf.extend_from_slice(&encoded[3..]); let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(decoded, msg); } #[test] fn test_tcp_response_codec_max_size() { use tokio_util::codec::Encoder; let msg = TcpResponseMessage::new(Bytes::from(vec![1, 2, 3, 4, 5])); let mut codec = TcpResponseCodec::new(Some(5)); // Too small let mut buf = BytesMut::new(); let result = codec.encode(msg, &mut buf); assert!(result.is_err()); } /// Demonstrates how framed codec enables testability without actual TCP connections #[tokio::test] async fn test_framed_codec_integration() { use futures::{SinkExt, StreamExt}; use std::io::Cursor; use tokio_util::codec::{FramedRead, FramedWrite}; // Simulate a duplex connection using in-memory buffer let mut buffer = Vec::new(); // Writer side: encode requests { let cursor = Cursor::new(&mut buffer); let mut writer = FramedWrite::new(cursor, TcpRequestCodec::new(None)); let msg1 = TcpRequestMessage::new("endpoint1".to_string(), Bytes::from("data1")); let msg2 = TcpRequestMessage::new("endpoint2".to_string(), Bytes::from("data2")); writer.send(msg1).await.unwrap(); writer.send(msg2).await.unwrap(); } // Reader side: decode requests { let cursor = Cursor::new(&buffer[..]); let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None)); let decoded1 = reader.next().await.unwrap().unwrap(); assert_eq!(decoded1.endpoint_path, "endpoint1"); assert_eq!(decoded1.payload, Bytes::from("data1")); let decoded2 = reader.next().await.unwrap().unwrap(); assert_eq!(decoded2.endpoint_path, "endpoint2"); assert_eq!(decoded2.payload, Bytes::from("data2")); } } /// Demonstrates testing partial message handling #[tokio::test] async fn test_framed_codec_partial_messages() { use futures::StreamExt; use std::io::Cursor; use tokio_util::codec::FramedRead; // Create a message and encode it let msg = TcpRequestMessage::new("test".to_string(), Bytes::from("hello")); let encoded = msg.encode().unwrap(); // Split the encoded message into chunks let chunk1 = &encoded[..5]; let chunk2 = &encoded[5..]; // Create a buffer that simulates receiving data in chunks let mut full_buffer = Vec::new(); full_buffer.extend_from_slice(chunk1); // Reader can't decode yet (partial data) { let cursor = Cursor::new(&full_buffer[..]); let _reader = FramedRead::new(cursor, TcpRequestCodec::new(None)); // In real async, this would return Ok(None) and wait for more data // For Cursor, it returns None at EOF } // Add the rest of the data full_buffer.extend_from_slice(chunk2); // Now decoding succeeds { let cursor = Cursor::new(&full_buffer[..]); let mut reader = FramedRead::new(cursor, TcpRequestCodec::new(None)); let decoded = reader.next().await.unwrap().unwrap(); assert_eq!(decoded.endpoint_path, "test"); assert_eq!(decoded.payload, Bytes::from("hello")); } } }