Unverified Commit bd4fe1a7 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files
parent 1954fcfa
...@@ -401,6 +401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -401,6 +401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de45108900e1f9b9242f7f2e254aa3e2c029c921c258fe9e6b4217eeebd54288" checksum = "de45108900e1f9b9242f7f2e254aa3e2c029c921c258fe9e6b4217eeebd54288"
dependencies = [ dependencies = [
"axum-core 0.5.2", "axum-core 0.5.2",
"axum-macros",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
...@@ -468,6 +469,17 @@ dependencies = [ ...@@ -468,6 +469,17 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "backoff" name = "backoff"
version = "0.4.0" version = "0.4.0"
...@@ -1889,6 +1901,7 @@ dependencies = [ ...@@ -1889,6 +1901,7 @@ dependencies = [
"tokio-util", "tokio-util",
"toktrie 1.1.0", "toktrie 1.1.0",
"toktrie_hf_tokenizers 1.1.0", "toktrie_hf_tokenizers 1.1.0",
"tower-http",
"tracing", "tracing",
"unicode-segmentation", "unicode-segmentation",
"url", "url",
...@@ -1975,6 +1988,7 @@ dependencies = [ ...@@ -1975,6 +1988,7 @@ dependencies = [
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
...@@ -7685,6 +7699,7 @@ dependencies = [ ...@@ -7685,6 +7699,7 @@ dependencies = [
"tower 0.5.2", "tower 0.5.2",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
......
...@@ -60,7 +60,8 @@ thiserror = { version = "2.0.11" } ...@@ -60,7 +60,8 @@ thiserror = { version = "2.0.11" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1" } tokio-stream = { version = "0.1" }
tokio-util = { version = "0.7", features = ["codec", "net"] } tokio-util = { version = "0.7", features = ["codec", "net"] }
axum = { version = "0.8" } tower-http = {version = "0.6", features=["trace"]}
axum = { version = "0.8" , features = ["macros"]}
tracing = { version = "0.1" } tracing = { version = "0.1" }
tracing-subscriber = { version = "0.3", features = ["env-filter", "local-time", "json"] } tracing-subscriber = { version = "0.3", features = ["env-filter", "local-time", "json"] }
validator = { version = "0.20.0", features = ["derive"] } validator = { version = "0.20.0", features = ["derive"] }
...@@ -68,7 +69,6 @@ uuid = { version = "1.17", features = ["v4", "serde"] } ...@@ -68,7 +69,6 @@ uuid = { version = "1.17", features = ["v4", "serde"] }
url = {version = "2.5", features = ["serde"]} url = {version = "2.5", features = ["serde"]}
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] } xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
[profile.dev.package] [profile.dev.package]
insta.opt-level = 3 insta.opt-level = 3
......
...@@ -323,6 +323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -323,6 +323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de45108900e1f9b9242f7f2e254aa3e2c029c921c258fe9e6b4217eeebd54288" checksum = "de45108900e1f9b9242f7f2e254aa3e2c029c921c258fe9e6b4217eeebd54288"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"axum-macros",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
...@@ -370,6 +371,17 @@ dependencies = [ ...@@ -370,6 +371,17 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "backoff" name = "backoff"
version = "0.4.0" version = "0.4.0"
...@@ -1177,6 +1189,7 @@ dependencies = [ ...@@ -1177,6 +1189,7 @@ dependencies = [
"tokio-util", "tokio-util",
"toktrie", "toktrie",
"toktrie_hf_tokenizers", "toktrie_hf_tokenizers",
"tower-http",
"tracing", "tracing",
"unicode-segmentation", "unicode-segmentation",
"url", "url",
...@@ -1252,6 +1265,7 @@ dependencies = [ ...@@ -1252,6 +1265,7 @@ dependencies = [
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
...@@ -4888,6 +4902,7 @@ dependencies = [ ...@@ -4888,6 +4902,7 @@ dependencies = [
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
......
...@@ -100,6 +100,8 @@ unicode-segmentation = "1.12" ...@@ -100,6 +100,8 @@ unicode-segmentation = "1.12"
# http-service # http-service
axum = { workspace = true } axum = { workspace = true }
tower-http = {workspace = true}
# tokenizers # tokenizers
tokenizers = { version = "0.21.4", default-features = false, features = [ tokenizers = { version = "0.21.4", default-features = false, features = [
......
...@@ -39,6 +39,8 @@ use crate::protocols::openai::{ ...@@ -39,6 +39,8 @@ use crate::protocols::openai::{
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::Annotated; use crate::types::Annotated;
use dynamo_runtime::logging::get_distributed_tracing_context;
use tracing::Instrument;
pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id"; pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id";
...@@ -132,6 +134,13 @@ impl From<HttpError> for ErrorMessage { ...@@ -132,6 +134,13 @@ impl From<HttpError> for ErrorMessage {
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present /// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String { fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
// Try to get request id from trace context
if let Some(trace_context) = get_distributed_tracing_context() {
if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
return x_dynamo_request_id;
}
}
// Try to get the request ID from the primary source // Try to get the request ID from the primary source
if let Some(primary) = primary { if let Some(primary) = primary {
if let Ok(uuid) = uuid::Uuid::parse_str(primary) { if let Ok(uuid) = uuid::Uuid::parse_str(primary) {
...@@ -181,7 +190,7 @@ async fn handler_completions( ...@@ -181,7 +190,7 @@ async fn handler_completions(
// possibly long running task // possibly long running task
// if this returns a streaming response, the stream handle will be armed and captured by the response stream // if this returns a streaming response, the stream handle will be armed and captured by the response stream
let response = tokio::spawn(completions(state, request, stream_handle)) let response = tokio::spawn(completions(state, request, stream_handle).in_current_span())
.await .await
.map_err(|e| { .map_err(|e| {
ErrorMessage::internal_server_error(&format!( ErrorMessage::internal_server_error(&format!(
...@@ -371,7 +380,8 @@ async fn handler_chat_completions( ...@@ -371,7 +380,8 @@ async fn handler_chat_completions(
// create the connection handles // create the connection handles
let (mut connection_handle, stream_handle) = create_connection_monitor(context.clone()).await; let (mut connection_handle, stream_handle) = create_connection_monitor(context.clone()).await;
let response = tokio::spawn(chat_completions(state, template, request, stream_handle)) let response =
tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span())
.await .await
.map_err(|e| { .map_err(|e| {
ErrorMessage::internal_server_error(&format!( ErrorMessage::internal_server_error(&format!(
...@@ -395,7 +405,6 @@ async fn handler_chat_completions( ...@@ -395,7 +405,6 @@ async fn handler_chat_completions(
/// ///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For /// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler. /// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.id()))]
async fn chat_completions( async fn chat_completions(
state: Arc<service_v2::State>, state: Arc<service_v2::State>,
template: Option<RequestTemplate>, template: Option<RequestTemplate>,
...@@ -597,7 +606,7 @@ async fn handler_responses( ...@@ -597,7 +606,7 @@ async fn handler_responses(
// create the connection handles // create the connection handles
let (mut connection_handle, _stream_handle) = create_connection_monitor(context.clone()).await; let (mut connection_handle, _stream_handle) = create_connection_monitor(context.clone()).await;
let response = tokio::spawn(responses(state, template, request)) let response = tokio::spawn(responses(state, template, request).in_current_span())
.await .await
.map_err(|e| { .map_err(|e| {
ErrorMessage::internal_server_error(&format!( ErrorMessage::internal_server_error(&format!(
......
...@@ -12,9 +12,11 @@ use crate::discovery::ModelManager; ...@@ -12,9 +12,11 @@ use crate::discovery::ModelManager;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::logging::make_request_span;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
/// HTTP service shared state /// HTTP service shared state
pub struct State { pub struct State {
...@@ -230,6 +232,9 @@ impl HttpServiceConfigBuilder { ...@@ -230,6 +232,9 @@ impl HttpServiceConfigBuilder {
all_docs.extend(route_docs); all_docs.extend(route_docs);
} }
// Add span for tracing
router = router.layer(TraceLayer::new_for_http().make_span_with(make_request_span));
Ok(HttpService { Ok(HttpService {
state, state,
router, router,
......
...@@ -42,6 +42,7 @@ serde_json = { workspace = true } ...@@ -42,6 +42,7 @@ serde_json = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tokio-stream = { workspace = true } tokio-stream = { workspace = true }
tokio-util = { workspace = true } tokio-util = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
......
...@@ -121,7 +121,10 @@ impl EndpointConfigBuilder { ...@@ -121,7 +121,10 @@ impl EndpointConfigBuilder {
// launch in primary runtime // launch in primary runtime
let task = tokio::spawn(push_endpoint.start( let task = tokio::spawn(push_endpoint.start(
service_endpoint, service_endpoint,
endpoint.component.namespace.name.clone(),
endpoint.component.name.clone(),
endpoint.name.clone(), endpoint.name.clone(),
lease_id,
endpoint.drt().system_health.clone(), endpoint.drt().system_health.clone(),
)); ));
......
...@@ -115,8 +115,8 @@ impl SystemHealth { ...@@ -115,8 +115,8 @@ impl SystemHealth {
self.system_health = status; self.system_health = status;
} }
pub fn set_endpoint_health_status(&mut self, endpoint: String, status: HealthStatus) { pub fn set_endpoint_health_status(&mut self, endpoint: &str, status: HealthStatus) {
self.endpoint_health.insert(endpoint, status); self.endpoint_health.insert(endpoint.to_string(), status);
} }
/// Returns the overall health status and endpoint health statuses /// Returns the overall health status and endpoint health statuses
......
...@@ -48,11 +48,15 @@ use tracing_subscriber::EnvFilter; ...@@ -48,11 +48,15 @@ use tracing_subscriber::EnvFilter;
use tracing_subscriber::{filter::Directive, fmt}; use tracing_subscriber::{filter::Directive, fmt};
use crate::config::{disable_ansi_logging, jsonl_logging_enabled}; use crate::config::{disable_ansi_logging, jsonl_logging_enabled};
use async_nats::{HeaderMap, HeaderValue};
use axum::extract::FromRequestParts; use axum::extract::FromRequestParts;
use axum::http;
use axum::http::request::Parts; use axum::http::request::Parts;
use axum::http::Request;
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::time::Instant; use std::time::Instant;
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
use tracing::field::Field; use tracing::field::Field;
use tracing::span; use tracing::span;
use tracing::Id; use tracing::Id;
...@@ -130,76 +134,130 @@ pub fn is_valid_span_id(span_id: &str) -> bool { ...@@ -130,76 +134,130 @@ pub fn is_valid_span_id(span_id: &str) -> bool {
pub struct DistributedTraceIdLayer; pub struct DistributedTraceIdLayer;
#[derive(Clone)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTraceContext { pub struct DistributedTraceContext {
trace_id: String, pub trace_id: String,
span_id: String, pub span_id: String,
parent_id: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
tracestate: Option<String>, pub parent_id: Option<String>,
start: Instant, #[serde(skip_serializing_if = "Option::is_none")]
pub tracestate: Option<String>,
#[serde(skip)]
start: Option<Instant>,
#[serde(skip)]
end: Option<Instant>, end: Option<Instant>,
x_request_id: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
pub x_request_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub x_dynamo_request_id: Option<String>,
}
impl DistributedTraceContext {
/// Create a traceparent string from the context
pub fn create_traceparent(&self) -> String {
format!("00-{}-{}-01", self.trace_id, self.span_id)
}
}
/// Parse a traceparent string into its components
pub fn parse_traceparent(traceparent: &str) -> (Option<String>, Option<String>) {
let pieces: Vec<_> = traceparent.split('-').collect();
if pieces.len() != 4 {
return (None, None);
}
let trace_id = pieces[1];
let parent_id = pieces[2];
if !is_valid_trace_id(trace_id) || !is_valid_span_id(parent_id) {
return (None, None);
}
(Some(trace_id.to_string()), Some(parent_id.to_string()))
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Default)]
pub struct TraceParent { pub struct TraceParent {
pub trace_id: Option<String>, pub trace_id: Option<String>,
pub parent_id: Option<String>, pub parent_id: Option<String>,
pub tracestate: Option<String>, pub tracestate: Option<String>,
pub x_request_id: Option<String>, pub x_request_id: Option<String>,
pub x_dynamo_request_id: Option<String>,
} }
impl<S> FromRequestParts<S> for TraceParent pub trait GenericHeaders {
where fn get(&self, key: &str) -> Option<&str>;
S: Send + Sync, }
{
type Rejection = Infallible; impl GenericHeaders for async_nats::HeaderMap {
fn get(&self, key: &str) -> Option<&str> {
async_nats::HeaderMap::get(self, key).map(|value| value.as_str())
}
}
impl GenericHeaders for http::HeaderMap {
fn get(&self, key: &str) -> Option<&str> {
http::HeaderMap::get(self, key).and_then(|value| value.to_str().ok())
}
}
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { impl TraceParent {
pub fn from_headers<H: GenericHeaders>(headers: &H) -> TraceParent {
let mut trace_id = None; let mut trace_id = None;
let mut parent_id = None; let mut parent_id = None;
let mut tracestate = None; let mut tracestate = None;
if let Some(header_value) = parts.headers.get("traceparent") { let mut x_request_id = None;
if let Ok(header_str) = header_value.to_str() { let mut x_dynamo_request_id = None;
let pieces: Vec<_> = header_str.split('-').collect();
if pieces.len() == 4 { if let Some(header_value) = headers.get("traceparent") {
let candidate_trace_id = pieces[1]; (trace_id, parent_id) = parse_traceparent(header_value);
let candidate_parent_id = pieces[2];
if is_valid_trace_id(candidate_trace_id)
&& is_valid_span_id(candidate_parent_id)
{
trace_id = Some(candidate_trace_id.to_string());
parent_id = Some(candidate_parent_id.to_string());
} else {
tracing::debug!("Invalid traceparent header: {}", header_str);
}
}
}
} }
if let Some(header_value) = parts.headers.get("tracestate") { if let Some(header_value) = headers.get("x-request-id") {
if let Ok(header_str) = header_value.to_str() { x_request_id = Some(header_value.to_string());
tracestate = Some(header_str.to_string());
} }
if let Some(header_value) = headers.get("tracestate") {
tracestate = Some(header_value.to_string());
} }
// Extract X-Request-ID or x-request-id (case-insensitive) if let Some(header_value) = headers.get("x-dynamo-request-id") {
let x_request_id = parts x_dynamo_request_id = Some(header_value.to_string());
.headers }
.get("x-request-id")
.and_then(|val| val.to_str().ok())
.map(|s| s.to_string());
Ok(TraceParent { // Validate UUID format
let x_dynamo_request_id =
x_dynamo_request_id.filter(|id| uuid::Uuid::parse_str(id).is_ok());
TraceParent {
trace_id, trace_id,
parent_id, parent_id,
tracestate, tracestate,
x_request_id, x_request_id,
}) x_dynamo_request_id,
}
} }
} }
// Takes Axum request and returning a span
pub fn make_request_span<B>(req: &Request<B>) -> Span {
let method = req.method();
let uri = req.uri();
let version = format!("{:?}", req.version());
let trace_parent = TraceParent::from_headers(req.headers());
tracing::info_span!(
"http-request",
method = %method,
uri = %uri,
version = %version,
trace_id = trace_parent.trace_id,
parent_id = trace_parent.parent_id,
x_request_id = trace_parent.x_request_id,
x_dynamo_request_id = trace_parent.x_dynamo_request_id,
)
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct FieldVisitor { pub struct FieldVisitor {
pub fields: HashMap<String, String>, pub fields: HashMap<String, String>,
...@@ -241,6 +299,7 @@ where ...@@ -241,6 +299,7 @@ where
let mut parent_id: Option<String> = None; let mut parent_id: Option<String> = None;
let mut span_id: Option<String> = None; let mut span_id: Option<String> = None;
let mut x_request_id: Option<String> = None; let mut x_request_id: Option<String> = None;
let mut x_dynamo_request_id: Option<String> = None;
let mut tracestate: Option<String> = None; let mut tracestate: Option<String> = None;
let mut visitor = FieldVisitor::default(); let mut visitor = FieldVisitor::default();
attrs.record(&mut visitor); attrs.record(&mut visitor);
...@@ -277,6 +336,10 @@ where ...@@ -277,6 +336,10 @@ where
x_request_id = Some(x_request_id_input.to_string()); x_request_id = Some(x_request_id_input.to_string());
} }
if let Some(x_request_id_input) = visitor.fields.get("x_dynamo_request_id") {
x_dynamo_request_id = Some(x_request_id_input.to_string());
}
if parent_id.is_none() { if parent_id.is_none() {
if let Some(parent_span_id) = ctx.current_span().id() { if let Some(parent_span_id) = ctx.current_span().id() {
if let Some(parent_span) = ctx.span(parent_span_id) { if let Some(parent_span) = ctx.span(parent_span_id) {
...@@ -312,9 +375,10 @@ where ...@@ -312,9 +375,10 @@ where
span_id: span_id.expect("Span ID must be set"), span_id: span_id.expect("Span ID must be set"),
parent_id, parent_id,
tracestate, tracestate,
start: Instant::now(), start: Some(Instant::now()),
end: None, end: None,
x_request_id, x_request_id,
x_dynamo_request_id,
}); });
} }
} }
...@@ -364,16 +428,17 @@ fn setup_logging() { ...@@ -364,16 +428,17 @@ fn setup_logging() {
#[cfg(not(feature = "tokio-console"))] #[cfg(not(feature = "tokio-console"))]
fn setup_logging() { fn setup_logging() {
let filter_layer = filters(load_config()); let fmt_filter_layer = filters(load_config());
let trace_filter_layer = filters(load_config());
if jsonl_logging_enabled() { if jsonl_logging_enabled() {
let l = fmt::layer() let l = fmt::layer()
.with_ansi(false) .with_ansi(false)
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
.event_format(CustomJsonFormatter::new()) .event_format(CustomJsonFormatter::new())
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_filter(filter_layer); .with_filter(fmt_filter_layer);
tracing_subscriber::registry() tracing_subscriber::registry()
.with(DistributedTraceIdLayer) .with(DistributedTraceIdLayer.with_filter(trace_filter_layer))
.with(l) .with(l)
.init(); .init();
} else { } else {
...@@ -381,7 +446,7 @@ fn setup_logging() { ...@@ -381,7 +446,7 @@ fn setup_logging() {
.with_ansi(!disable_ansi_logging()) .with_ansi(!disable_ansi_logging())
.event_format(fmt::format().compact().with_timer(TimeFormatter::new())) .event_format(fmt::format().compact().with_timer(TimeFormatter::new()))
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_filter(filter_layer); .with_filter(fmt_filter_layer);
tracing_subscriber::registry().with(l).init(); tracing_subscriber::registry().with(l).init();
} }
} }
...@@ -616,6 +681,15 @@ where ...@@ -616,6 +681,15 @@ where
} else { } else {
visitor.fields.remove("x_request_id"); visitor.fields.remove("x_request_id");
} }
if let Some(x_dynamo_request_id) = tracing_context.x_dynamo_request_id.clone() {
visitor.fields.insert(
"x_dynamo_request_id".to_string(),
serde_json::Value::String(x_dynamo_request_id),
);
} else {
visitor.fields.remove("x_dynamo_request_id");
}
} else { } else {
tracing::error!( tracing::error!(
"Distributed Trace Context not found, falling back to internal ids" "Distributed Trace Context not found, falling back to internal ids"
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
use crate::config::HealthStatus; use crate::config::HealthStatus;
use crate::logging::make_request_span;
use crate::logging::TraceParent; use crate::logging::TraceParent;
use crate::metrics::MetricsRegistry; use crate::metrics::MetricsRegistry;
use crate::traits::DistributedRuntimeProvider; use crate::traits::DistributedRuntimeProvider;
...@@ -25,8 +26,8 @@ use std::sync::OnceLock; ...@@ -25,8 +26,8 @@ use std::sync::OnceLock;
use std::time::Instant; use std::time::Instant;
use tokio::{net::TcpListener, task::JoinHandle}; use tokio::{net::TcpListener, task::JoinHandle};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing; use tower_http::trace::DefaultMakeSpan;
use tracing::Instrument; use tower_http::trace::TraceLayer;
/// Metrics server information containing socket address and handle /// Metrics server information containing socket address and handle
#[derive(Debug)] #[derive(Debug)]
...@@ -155,36 +156,28 @@ pub async fn spawn_metrics_server( ...@@ -155,36 +156,28 @@ pub async fn spawn_metrics_server(
&health_path, &health_path,
get({ get({
let state = Arc::clone(&server_state); let state = Arc::clone(&server_state);
move |tracing_ctx| health_handler(state, "health", tracing_ctx) move || health_handler(state)
}), }),
) )
.route( .route(
&live_path, &live_path,
get({ get({
let state = Arc::clone(&server_state); let state = Arc::clone(&server_state);
move |tracing_ctx| health_handler(state, "live", tracing_ctx) move || health_handler(state)
}), }),
) )
.route( .route(
"/metrics", "/metrics",
get({ get({
let state = Arc::clone(&server_state); let state = Arc::clone(&server_state);
move |tracing_ctx| metrics_handler(state, "metrics", tracing_ctx) move || metrics_handler(state)
}), }),
) )
.fallback(|tracing_ctx: TraceParent| { .fallback(|| async {
async {
tracing::info!("[fallback handler] called"); tracing::info!("[fallback handler] called");
(StatusCode::NOT_FOUND, "Route not found").into_response() (StatusCode::NOT_FOUND, "Route not found").into_response()
} })
.instrument(tracing::trace_span!( .layer(TraceLayer::new_for_http().make_span_with(make_request_span));
"fallback handler",
trace_id = tracing_ctx.trace_id,
parent_id = tracing_ctx.parent_id,
x_request_id = tracing_ctx.x_request_id,
tracestate = tracing_ctx.tracestate
))
});
let address = format!("{}:{}", host, port); let address = format!("{}:{}", host, port);
tracing::info!("[spawn_metrics_server] binding to: {}", address); tracing::info!("[spawn_metrics_server] binding to: {}", address);
...@@ -220,16 +213,7 @@ pub async fn spawn_metrics_server( ...@@ -220,16 +213,7 @@ pub async fn spawn_metrics_server(
} }
/// Health handler /// Health handler
#[tracing::instrument(skip_all, level="trace", fields(route= %route, async fn health_handler(state: Arc<MetricsServerState>) -> impl IntoResponse {
trace_id = trace_parent.trace_id,
parent_id = trace_parent.parent_id,
x_request_id= trace_parent.x_request_id,
tracestate= trace_parent.tracestate))]
async fn health_handler(
state: Arc<MetricsServerState>,
route: &'static str, // Used for tracing only
trace_parent: TraceParent, // Used for tracing only
) -> impl IntoResponse {
let (mut healthy, endpoints) = state let (mut healthy, endpoints) = state
.drt() .drt()
.system_health .system_health
...@@ -264,16 +248,7 @@ async fn health_handler( ...@@ -264,16 +248,7 @@ async fn health_handler(
} }
/// Metrics handler with DistributedRuntime uptime /// Metrics handler with DistributedRuntime uptime
#[tracing::instrument(skip_all, level="trace", fields(route= %route, async fn metrics_handler(state: Arc<MetricsServerState>) -> impl IntoResponse {
trace_id = trace_parent.trace_id,
parent_id = trace_parent.parent_id,
x_request_id = trace_parent.x_request_id,
tracestate = trace_parent.tracestate))]
async fn metrics_handler(
state: Arc<MetricsServerState>,
route: &'static str, // Used for tracing only
trace_parent: TraceParent, // Used for tracing only
) -> impl IntoResponse {
// Update the uptime gauge with current value // Update the uptime gauge with current value
state.update_uptime_gauge(); state.update_uptime_gauge();
......
...@@ -14,11 +14,15 @@ ...@@ -14,11 +14,15 @@
// limitations under the License. // limitations under the License.
use async_nats::client::Client; use async_nats::client::Client;
use async_nats::{HeaderMap, HeaderValue};
use tracing as log; use tracing as log;
use super::*; use super::*;
use crate::logging::get_distributed_tracing_context;
use crate::logging::DistributedTraceContext;
use crate::{protocols::maybe_error::MaybeError, Result}; use crate::{protocols::maybe_error::MaybeError, Result};
use tokio_stream::{wrappers::ReceiverStream, StreamExt, StreamNotifyClose}; use tokio_stream::{wrappers::ReceiverStream, StreamExt, StreamNotifyClose};
use tracing::Instrument;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -148,11 +152,29 @@ where ...@@ -148,11 +152,29 @@ where
log::trace!(request_id, "enqueueing two-part message to nats"); log::trace!(request_id, "enqueueing two-part message to nats");
// Insert Trace Context into Headers
// Enables span to be created in push_endpoint before
// payload is parsed
let mut headers = HeaderMap::new();
if let Some(trace_context) = get_distributed_tracing_context() {
headers.insert("traceparent", trace_context.create_traceparent());
if let Some(tracestate) = trace_context.tracestate {
headers.insert("tracestate", tracestate);
}
if let Some(x_request_id) = trace_context.x_request_id {
headers.insert("x-request-id", x_request_id);
}
if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
headers.insert("x-dynamo-request-id", x_dynamo_request_id);
}
}
// we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats // we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats
// will handle this for us // will handle this for us
let _response = self let _response = self
.req_transport .req_transport
.request(address.to_string(), buffer) .request_with_headers(address.to_string(), headers, buffer)
.await?; .await?;
log::trace!(request_id, "awaiting transport handshake"); log::trace!(request_id, "awaiting transport handshake");
......
...@@ -17,6 +17,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; ...@@ -17,6 +17,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use super::*; use super::*;
use crate::config::HealthStatus; use crate::config::HealthStatus;
use crate::logging::TraceParent;
use crate::protocols::LeaseId; use crate::protocols::LeaseId;
use crate::SystemHealth; use crate::SystemHealth;
use anyhow::Result; use anyhow::Result;
...@@ -26,6 +27,7 @@ use std::collections::HashMap; ...@@ -26,6 +27,7 @@ use std::collections::HashMap;
use std::sync::Mutex; use std::sync::Mutex;
use tokio::sync::Notify; use tokio::sync::Notify;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::Instrument;
#[derive(Builder)] #[derive(Builder)]
pub struct PushEndpoint { pub struct PushEndpoint {
...@@ -46,18 +48,24 @@ impl PushEndpoint { ...@@ -46,18 +48,24 @@ impl PushEndpoint {
pub async fn start( pub async fn start(
self, self,
endpoint: Endpoint, endpoint: Endpoint,
namespace: String,
component_name: String,
endpoint_name: String, endpoint_name: String,
instance_id: i64,
system_health: Arc<Mutex<SystemHealth>>, system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()> { ) -> Result<()> {
let mut endpoint = endpoint; let mut endpoint = endpoint;
let inflight = Arc::new(AtomicU64::new(0)); let inflight = Arc::new(AtomicU64::new(0));
let notify = Arc::new(Notify::new()); let notify = Arc::new(Notify::new());
let component_name_local: Arc<String> = Arc::from(component_name);
let endpoint_name_local: Arc<String> = Arc::from(endpoint_name);
let namespace_local: Arc<String> = Arc::from(namespace);
system_health system_health
.lock() .lock()
.unwrap() .unwrap()
.set_endpoint_health_status(endpoint_name.clone(), HealthStatus::Ready); .set_endpoint_health_status(endpoint_name_local.as_str(), HealthStatus::Ready);
loop { loop {
let req = tokio::select! { let req = tokio::select! {
...@@ -85,19 +93,47 @@ impl PushEndpoint { ...@@ -85,19 +93,47 @@ impl PushEndpoint {
} }
let ingress = self.service_handler.clone(); let ingress = self.service_handler.clone();
let worker_id = "".to_string(); let endpoint_name: Arc<String> = Arc::clone(&endpoint_name_local);
let component_name: Arc<String> = Arc::clone(&component_name_local);
let namespace: Arc<String> = Arc::clone(&namespace_local);
// increment the inflight counter // increment the inflight counter
inflight.fetch_add(1, Ordering::SeqCst); inflight.fetch_add(1, Ordering::SeqCst);
let inflight_clone = inflight.clone(); let inflight_clone = inflight.clone();
let notify_clone = notify.clone(); let notify_clone = notify.clone();
// Handle headers here for tracing
let mut traceparent = TraceParent::default();
if let Some(headers) = req.message.headers.as_ref() {
traceparent = TraceParent::from_headers(headers);
}
tokio::spawn(async move { tokio::spawn(async move {
tracing::trace!(worker_id, "handling new request"); tracing::trace!(instance_id, "handling new request");
let result = ingress.handle_payload(req.message.payload).await; let result = ingress
.handle_payload(req.message.payload)
.instrument(
// Create span with trace ids as set
// in headers.
tracing::info_span!(
"handle_payload",
component = component_name.as_ref(),
endpoint = endpoint_name.as_ref(),
namespace = namespace.as_ref(),
instance_id = instance_id,
trace_id = traceparent.trace_id,
parent_id = traceparent.parent_id,
x_request_id = traceparent.x_request_id,
x_dynamo_request_id = traceparent.x_dynamo_request_id,
tracestate = traceparent.tracestate
),
)
.await;
match result { match result {
Ok(_) => { Ok(_) => {
tracing::trace!(worker_id, "request handled successfully"); tracing::trace!(instance_id, "request handled successfully");
} }
Err(e) => { Err(e) => {
tracing::warn!("Failed to handle request: {:?}", e); tracing::warn!("Failed to handle request: {:?}", e);
...@@ -116,7 +152,7 @@ impl PushEndpoint { ...@@ -116,7 +152,7 @@ impl PushEndpoint {
system_health system_health
.lock() .lock()
.unwrap() .unwrap()
.set_endpoint_health_status(endpoint_name.clone(), HealthStatus::NotReady); .set_endpoint_health_status(endpoint_name_local.as_str(), HealthStatus::NotReady);
// await for all inflight requests to complete if graceful shutdown // await for all inflight requests to complete if graceful shutdown
if self.graceful_shutdown { if self.graceful_shutdown {
......
...@@ -18,6 +18,8 @@ use crate::protocols::maybe_error::MaybeError; ...@@ -18,6 +18,8 @@ use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge}; use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use tracing::info_span;
use tracing::Instrument;
/// Metrics configuration for profiling work handlers /// Metrics configuration for profiling work handlers
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment