"lib/bindings/vscode:/vscode.git/clone" did not exist on "268d017e24c145a514fd267fa976b6e55a01bc44"
Unverified commit ace35a8e, authored by jh-nv and committed by GitHub
Browse files

fix: distributed tracing propagation for TCP transport (#5283)


Co-authored-by: Ishan Dhanani <ishandhanani@gmail.com>
parent d0cfc40a
......@@ -278,6 +278,8 @@ pub fn make_request_span<B>(req: &Request<B>) -> Span {
let version = format!("{:?}", req.version());
let trace_parent = TraceParent::from_headers(req.headers());
let otel_context = extract_otel_context_from_http_headers(req.headers());
let span = tracing::info_span!(
"http-request",
method = %method,
......@@ -286,12 +288,52 @@ pub fn make_request_span<B>(req: &Request<B>) -> Span {
trace_id = trace_parent.trace_id,
parent_id = trace_parent.parent_id,
x_request_id = trace_parent.x_request_id,
x_dynamo_request_id = trace_parent.x_dynamo_request_id,
x_dynamo_request_id = trace_parent.x_dynamo_request_id,
);
if let Some(context) = otel_context {
let _ = span.set_parent(context);
}
span
}
/// Builds an OpenTelemetry parent context from the trace headers of an HTTP request.
///
/// Returns `None` when the `traceparent` header is missing, cannot be read as a string,
/// is empty, or when the span context produced by the propagator is not valid.
fn extract_otel_context_from_http_headers(
    headers: &http::HeaderMap,
) -> Option<opentelemetry::Context> {
    /// Read-only adapter that lets the propagator read from an `http::HeaderMap`.
    struct HeaderMapCarrier<'h>(&'h http::HeaderMap);

    impl<'h> Extractor for HeaderMapCarrier<'h> {
        fn get(&self, key: &str) -> Option<&str> {
            let value = self.0.get(key)?;
            value.to_str().ok()
        }

        fn keys(&self) -> Vec<&str> {
            // Only the two trace-context headers are ever reported, and only if present.
            let mut present = Vec::new();
            for name in ["traceparent", "tracestate"] {
                if self.0.get(name).is_some() {
                    present.push(name);
                }
            }
            present
        }
    }

    // Bail out before touching the propagator when there is nothing usable to extract.
    let traceparent = headers.get("traceparent")?.to_str().ok()?;
    if traceparent.is_empty() {
        return None;
    }

    let extracted = TRACE_PROPAGATOR.extract(&HeaderMapCarrier(headers));
    Some(extracted).filter(|ctx| ctx.span().span_context().is_valid())
}
/// Create a handle_payload span from NATS headers with component context
pub fn make_handle_payload_span(
headers: &async_nats::HeaderMap,
......@@ -335,6 +377,93 @@ pub fn make_handle_payload_span(
}
}
/// Builds the `handle_payload` span for a request that arrived with TCP/HashMap headers.
///
/// The `x-request-id`, `x-dynamo-request-id` and `tracestate` entries of `headers` are copied
/// onto the span when present, together with the component context passed in.
///
/// When both a trace id and a parent span id are available from the `traceparent` entry,
/// they are recorded as `trace_id` / `parent_id` and the extracted OpenTelemetry context
/// (if any) is set as the span's parent. Otherwise the span carries no trace fields and no
/// explicit parent is set.
pub fn make_handle_payload_span_from_tcp_headers(
    headers: &std::collections::HashMap<String, String>,
    component: &str,
    endpoint: &str,
    namespace: &str,
    instance_id: u64,
) -> Span {
    let header = |name: &str| headers.get(name).cloned();
    let request_id = header("x-request-id");
    let dynamo_request_id = header("x-dynamo-request-id");
    let trace_state = header("tracestate");

    let (parent_context, trace_id, parent_span_id) =
        extract_otel_context_from_tcp_headers(headers);

    match (trace_id.as_deref(), parent_span_id.as_deref()) {
        (Some(trace), Some(parent)) => {
            let span = tracing::info_span!(
                "handle_payload",
                trace_id = trace,
                parent_id = parent,
                x_request_id = request_id,
                x_dynamo_request_id = dynamo_request_id,
                tracestate = trace_state,
                component = component,
                endpoint = endpoint,
                namespace = namespace,
                instance_id = instance_id,
            );
            // Failure to attach the remote parent is deliberately ignored.
            if let Some(context) = parent_context {
                let _ = span.set_parent(context);
            }
            span
        }
        _ => tracing::info_span!(
            "handle_payload",
            x_request_id = request_id,
            x_dynamo_request_id = dynamo_request_id,
            tracestate = trace_state,
            component = component,
            endpoint = endpoint,
            namespace = namespace,
            instance_id = instance_id,
        ),
    }
}
/// Pulls trace information out of TCP/HashMap request headers.
///
/// Returns `(context, trace_id, parent_span_id)`:
/// - all three are `None` when `headers` has no `traceparent` entry;
/// - `trace_id` and `parent_span_id` are whatever `parse_traceparent` yields for that entry;
/// - `context` is `Some` only when the span context produced by the propagator is valid.
fn extract_otel_context_from_tcp_headers(
    headers: &std::collections::HashMap<String, String>,
) -> (Option<opentelemetry::Context>, Option<String>, Option<String>) {
    /// Read-only adapter that lets the propagator read from a plain string map.
    struct MapCarrier<'m>(&'m std::collections::HashMap<String, String>);

    impl<'m> Extractor for MapCarrier<'m> {
        fn get(&self, key: &str) -> Option<&str> {
            self.0.get(key).map(String::as_str)
        }

        fn keys(&self) -> Vec<&str> {
            // Only the two trace-context headers are ever reported, and only if present.
            ["traceparent", "tracestate"]
                .iter()
                .copied()
                .filter(|name| self.0.contains_key(*name))
                .collect()
        }
    }

    let Some(raw_traceparent) = headers.get("traceparent") else {
        return (None, None, None);
    };
    let (trace_id, parent_span_id) = parse_traceparent(raw_traceparent.as_str());

    let valid_context = Some(TRACE_PROPAGATOR.extract(&MapCarrier(headers)))
        .filter(|ctx| ctx.span().span_context().is_valid());

    (valid_context, trace_id, parent_span_id)
}
/// Extract OpenTelemetry trace context from NATS headers for distributed tracing
pub fn extract_otel_context_from_nats_headers(
headers: &async_nats::HeaderMap,
......@@ -366,8 +495,7 @@ pub fn extract_otel_context_from_nats_headers(
}
let extractor = NatsHeaderExtractor(headers);
let propagator = opentelemetry_sdk::propagation::TraceContextPropagator::new();
let otel_context = propagator.extract(&extractor);
let otel_context = TRACE_PROPAGATOR.extract(&extractor);
let context_with_trace = if otel_context.span().span_context().is_valid() {
Some(otel_context)
......@@ -394,8 +522,7 @@ pub fn inject_otel_context_into_nats_headers(
}
let mut injector = NatsHeaderInjector(headers);
let propagator = opentelemetry_sdk::propagation::TraceContextPropagator::new();
propagator.inject_context(&otel_context, &mut injector);
TRACE_PROPAGATOR.inject_context(&otel_context, &mut injector);
}
/// Inject trace context from current span into NATS headers
......@@ -948,6 +1075,11 @@ impl CustomJsonFormatter {
use once_cell::sync::Lazy;
use regex::Regex;
/// Shared W3C Trace Context propagator.
///
/// Built lazily on first access and then reused by the header extract/inject helpers in
/// this file, so they do not each construct their own `TraceContextPropagator`.
static TRACE_PROPAGATOR: Lazy<opentelemetry_sdk::propagation::TraceContextPropagator> =
Lazy::new(opentelemetry_sdk::propagation::TraceContextPropagator::new);
fn parse_tracing_duration(s: &str) -> Option<u64> {
static RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r#"^["']?\s*([0-9.]+)\s*(µs|us|ns|ms|s)\s*["']?$"#).unwrap());
......
......@@ -18,16 +18,19 @@ mod two_part;
pub use two_part::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
/// TCP request plane protocol message with endpoint routing and trace headers
///
/// Wire format:
/// - endpoint_path_len: u16 (big-endian)
/// - endpoint_path: UTF-8 string
/// - headers_len: u16 (big-endian)
/// - headers: JSON-encoded HashMap<String, String>
/// - payload_len: u32 (big-endian)
/// - payload: bytes
///
/// Because of the length prefixes, the JSON-encoded headers must fit in `u16::MAX` bytes
/// and the payload in `u32::MAX` bytes; encoding rejects anything larger.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpRequestMessage {
/// Endpoint path sent first on the wire; the TCP server reads it before looking up a handler.
pub endpoint_path: String,
/// String key/value headers, serialized as JSON on the wire. The TCP server passes them to
/// `make_handle_payload_span_from_tcp_headers` to build the request span.
pub headers: std::collections::HashMap<String, String>,
/// Opaque request body bytes.
pub payload: Bytes,
}
......@@ -35,6 +38,19 @@ impl TcpRequestMessage {
pub fn new(endpoint_path: String, payload: Bytes) -> Self {
Self {
endpoint_path,
headers: std::collections::HashMap::new(),
payload,
}
}
/// Creates a message for `endpoint_path` carrying `payload` together with the given
/// `headers`.
///
/// The arguments are stored as-is; no validation or size check happens here.
pub fn with_headers(
endpoint_path: String,
headers: std::collections::HashMap<String, String>,
payload: Bytes,
) -> Self {
Self {
endpoint_path,
headers,
payload,
}
}
......@@ -51,6 +67,22 @@ impl TcpRequestMessage {
));
}
// Encode headers as JSON
let headers_json = serde_json::to_vec(&self.headers).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Failed to encode headers: {}", e),
)
})?;
let headers_len = headers_json.len();
if headers_len > u16::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Headers too large: {} bytes", headers_len),
));
}
if self.payload.len() > u32::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
......@@ -59,7 +91,8 @@ impl TcpRequestMessage {
}
// Use BytesMut for efficient buffer building
let mut buf = BytesMut::with_capacity(2 + endpoint_len + 4 + self.payload.len());
let mut buf =
BytesMut::with_capacity(2 + endpoint_len + 2 + headers_len + 4 + self.payload.len());
// Write endpoint path length (2 bytes)
buf.put_u16(endpoint_len as u16);
......@@ -67,6 +100,12 @@ impl TcpRequestMessage {
// Write endpoint path
buf.put_slice(endpoint_bytes);
// Write headers length (2 bytes)
buf.put_u16(headers_len as u16);
// Write headers
buf.put_slice(&headers_json);
// Write payload length (4 bytes)
buf.put_u32(self.payload.len() as u32);
......@@ -102,11 +141,39 @@ impl TcpRequestMessage {
.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Invalid UTF-8: {}", e),
format!("Invalid UTF-8 in endpoint path: {}", e),
)
})?;
offset += endpoint_len;
if bytes.len() < offset + 2 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Not enough bytes for headers length",
));
}
// Read headers length (2 bytes)
let headers_len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]) as usize;
offset += 2;
if bytes.len() < offset + headers_len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Not enough bytes for headers",
));
}
// Read and parse headers
let headers: std::collections::HashMap<String, String> =
serde_json::from_slice(&bytes[offset..offset + headers_len]).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Invalid JSON in headers: {}", e),
)
})?;
offset += headers_len;
if bytes.len() < offset + 4 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
......@@ -139,6 +206,7 @@ impl TcpRequestMessage {
Ok(Self {
endpoint_path,
headers,
payload,
})
}
......@@ -169,14 +237,25 @@ impl Decoder for TcpRequestCodec {
// Peek at endpoint path length without consuming
let endpoint_len = u16::from_be_bytes([src[0], src[1]]) as usize;
let header_size = 2 + endpoint_len + 4; // path_len + path + payload_len
// Need path + headers_len
if src.len() < 2 + endpoint_len + 2 {
return Ok(None);
}
// Peek at headers length
let headers_len_offset = 2 + endpoint_len;
let headers_len =
u16::from_be_bytes([src[headers_len_offset], src[headers_len_offset + 1]]) as usize;
// Need path + headers + payload_len
let header_size = 2 + endpoint_len + 2 + headers_len + 4;
if src.len() < header_size {
return Ok(None);
}
// Peek at payload length
let payload_len_offset = 2 + endpoint_len;
let payload_len_offset = 2 + endpoint_len + 2 + headers_len;
let payload_len = u32::from_be_bytes([
src[payload_len_offset],
src[payload_len_offset + 1],
......@@ -204,7 +283,7 @@ impl Decoder for TcpRequestCodec {
return Ok(None);
}
// We have a complete message, advance past length prefix
// We have a complete message, advance past endpoint path length prefix
src.advance(2);
// Read endpoint path
......@@ -216,6 +295,19 @@ impl Decoder for TcpRequestCodec {
)
})?;
// Advance past headers length
src.advance(2);
// Read and parse headers
let headers_bytes = src.split_to(headers_len);
let headers: std::collections::HashMap<String, String> =
serde_json::from_slice(&headers_bytes).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Invalid JSON in headers: {}", e),
)
})?;
// Advance past payload length
src.advance(4);
......@@ -224,6 +316,7 @@ impl Decoder for TcpRequestCodec {
Ok(Some(TcpRequestMessage {
endpoint_path,
headers,
payload,
}))
}
......@@ -243,6 +336,22 @@ impl Encoder<TcpRequestMessage> for TcpRequestCodec {
));
}
// Encode headers as JSON
let headers_json = serde_json::to_vec(&item.headers).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Failed to encode headers: {}", e),
)
})?;
let headers_len = headers_json.len();
if headers_len > u16::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Headers too large: {} bytes", headers_len),
));
}
if item.payload.len() > u32::MAX as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
......@@ -250,7 +359,7 @@ impl Encoder<TcpRequestMessage> for TcpRequestCodec {
));
}
let total_len = 2 + endpoint_len + 4 + item.payload.len();
let total_len = 2 + endpoint_len + 2 + headers_len + 4 + item.payload.len();
// Check max message size
if let Some(max_size) = self.max_message_size
......@@ -274,6 +383,12 @@ impl Encoder<TcpRequestMessage> for TcpRequestCodec {
// Write endpoint path
dst.put_slice(endpoint_bytes);
// Write headers length
dst.put_u16(headers_len as u16);
// Write headers
dst.put_slice(&headers_json);
// Write payload length
dst.put_u32(item.payload.len() as u32);
......
......@@ -324,7 +324,8 @@ impl TcpConnection {
// Encode request on caller's thread (hot path optimization)
// This allows multiple concurrent callers to encode in parallel
// rather than serializing through the writer task
let request_msg = TcpRequestMessage::new(endpoint_path, payload);
// Include all headers (especially trace headers) in the message
let request_msg = TcpRequestMessage::with_headers(endpoint_path, headers.clone(), payload);
let encoded_data = request_msg.encode()?;
// Create response channel
......@@ -657,7 +658,7 @@ mod tests {
let (stream, _) = listener.accept().await.unwrap();
let (mut read_half, mut write_half) = tokio::io::split(stream);
// Read request
// Read path length and path
let mut len_buf = [0u8; 2];
read_half.read_exact(&mut len_buf).await.unwrap();
let path_len = u16::from_be_bytes(len_buf) as usize;
......@@ -665,6 +666,15 @@ mod tests {
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await.unwrap();
// Read headers length and headers
let mut headers_len_buf = [0u8; 2];
read_half.read_exact(&mut headers_len_buf).await.unwrap();
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
read_half.read_exact(&mut headers_buf).await.unwrap();
// Read payload length and payload
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await.unwrap();
let payload_len = u32::from_be_bytes(len_buf) as usize;
......@@ -728,6 +738,17 @@ mod tests {
break;
}
let mut headers_len_buf = [0u8; 2];
if read_half.read_exact(&mut headers_len_buf).await.is_err() {
break;
}
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
if read_half.read_exact(&mut headers_buf).await.is_err() {
break;
}
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
......@@ -826,6 +847,17 @@ mod tests {
break;
}
let mut headers_len_buf = [0u8; 2];
if read_half.read_exact(&mut headers_len_buf).await.is_err() {
break;
}
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
let mut headers_buf = vec![0u8; headers_len];
if read_half.read_exact(&mut headers_buf).await.is_err() {
break;
}
let mut len_buf = [0u8; 4];
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
......
......@@ -266,6 +266,15 @@ impl SharedTcpServer {
let mut path_buf = vec![0u8; path_len];
read_half.read_exact(&mut path_buf).await?;
// Read headers length (2 bytes)
let mut headers_len_buf = [0u8; 2];
read_half.read_exact(&mut headers_len_buf).await?;
let headers_len = u16::from_be_bytes(headers_len_buf) as usize;
// Read headers
let mut headers_buf = vec![0u8; headers_len];
read_half.read_exact(&mut headers_buf).await?;
// Read payload length (4 bytes)
let mut len_buf = [0u8; 4];
read_half.read_exact(&mut len_buf).await?;
......@@ -293,9 +302,12 @@ impl SharedTcpServer {
read_half.read_exact(&mut payload_buf).await?;
// Reconstruct the full message buffer for decoding using BytesMut
let mut full_msg = BytesMut::with_capacity(2 + path_len + 4 + payload_len);
let mut full_msg =
BytesMut::with_capacity(2 + path_len + 2 + headers_len + 4 + payload_len);
full_msg.extend_from_slice(&path_len_buf);
full_msg.extend_from_slice(&path_buf);
full_msg.extend_from_slice(&headers_len_buf);
full_msg.extend_from_slice(&headers_buf);
full_msg.extend_from_slice(&len_buf);
full_msg.extend_from_slice(&payload_buf);
......@@ -316,6 +328,7 @@ impl SharedTcpServer {
};
let endpoint_path = request_msg.endpoint_path;
let headers = request_msg.headers;
let payload = request_msg.payload;
// Look up handler (lock-free read with DashMap)
......@@ -361,15 +374,18 @@ impl SharedTcpServer {
tokio::spawn(async move {
tracing::trace!(instance_id, "handling TCP request");
// Create span with trace context from headers
let span = crate::logging::make_handle_payload_span_from_tcp_headers(
&headers,
&component_name,
&endpoint_name,
&namespace,
instance_id,
);
let result = service_handler
.handle_payload(payload)
.instrument(tracing::info_span!(
"handle_payload",
component = component_name.as_str(),
endpoint = endpoint_name.as_str(),
namespace = namespace.as_str(),
instance_id = instance_id,
))
.instrument(span)
.await;
match result {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment