Unverified Commit f2e2e935 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

feat: Add distributed tracing context support to Python bindings (#3160)


Signed-off-by: nnshah1 <neelays@nvidia.com>
Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: Jacky <18255193+kthui@users.noreply.github.com>
parent c03e2f6b
......@@ -3,6 +3,7 @@
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
use dynamo_runtime::logging::DistributedTraceContext;
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use dynamo_runtime::pipeline::context::Controller;
use pyo3::prelude::*;
......@@ -15,11 +16,23 @@ use std::sync::Arc;
// Python-visible wrapper around an engine context, optionally paired with a
// distributed trace context.
#[pyclass]
pub struct Context {
// Shared handle to the wrapped engine context.
inner: Arc<dyn AsyncEngineContext>,
// Distributed trace context supplied at construction; `None` when the
// context was created without trace information.
trace_context: Option<DistributedTraceContext>,
}
impl Context {
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
pub fn new(
inner: Arc<dyn AsyncEngineContext>,
trace_context: Option<DistributedTraceContext>,
) -> Self {
Self {
inner,
trace_context,
}
}
// Get trace context for Rust-side usage
pub fn trace_context(&self) -> Option<&DistributedTraceContext> {
self.trace_context.as_ref()
}
pub fn inner(&self) -> Arc<dyn AsyncEngineContext> {
......@@ -38,6 +51,7 @@ impl Context {
};
Self {
inner: Arc::new(controller),
trace_context: None,
}
}
......@@ -74,6 +88,24 @@ impl Context {
}
})
}
// Expose trace information to Python for debugging.
// Returns a copy of the trace id, or `None` when no trace context is attached.
#[getter]
fn trace_id(&self) -> Option<String> {
    let trace = self.trace_context.as_ref()?;
    Some(trace.trace_id.clone())
}
// Returns a copy of the span id, or `None` when no trace context is attached.
#[getter]
fn span_id(&self) -> Option<String> {
    let trace = self.trace_context.as_ref()?;
    Some(trace.span_id.clone())
}
// Returns a copy of the parent span id; `None` when no trace context is
// attached or when the attached context has no parent id.
#[getter]
fn parent_span_id(&self) -> Option<String> {
    let trace = self.trace_context.as_ref()?;
    trace.parent_id.clone()
}
}
// PyO3 equivalent for verifying whether the signature contains target_name
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::context::{Context, callable_accepts_kwarg};
use dynamo_runtime::logging::get_distributed_tracing_context;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
......@@ -155,6 +156,9 @@ where
let id = context.id().to_string();
tracing::trace!("processing request: {}", id);
// Capture current trace context
let current_trace_context = get_distributed_tracing_context();
// Clone the PyObject to move into the thread
// Create a channel to communicate between the Python thread and the Rust async context
......@@ -178,7 +182,9 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let py_ctx = Py::new(py, Context::new(ctx_python.clone()))?;
// Create context with trace information
let py_ctx = Py::new(py, Context::new(ctx_python.clone(), current_trace_context))?;
let gen_result = if has_context {
// Pass context as a kwarg
......
......@@ -15,6 +15,7 @@ use std::path::PathBuf;
use std::time::Duration;
use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex;
use tracing::{Instrument, info_span};
use dynamo_runtime::{
self as rs, logging,
......@@ -63,6 +64,65 @@ static INIT: OnceCell<()> = OnceCell::new();
const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);
/// Builds the `client_request` span used to instrument an outgoing client call.
///
/// A span is always produced. `operation` and `request_id` are attached in
/// every case. When `trace_context` is present, its trace id, its span id
/// (recorded under `parent_id`), and its `x_request_id`,
/// `x_dynamo_request_id`, and `tracestate` values are attached as well; an
/// empty string is recorded for any of those optional values that is absent.
fn create_client_span(
    operation: &str,
    request_id: &str,
    trace_context: Option<&dynamo_runtime::logging::DistributedTraceContext>,
) -> tracing::Span {
    match trace_context {
        None => info_span!(
            "client_request",
            operation = operation,
            request_id = request_id,
        ),
        Some(trace) => {
            let trace_id = trace.trace_id.as_str();
            // The incoming context's span becomes the parent of this span.
            let parent_id = trace.span_id.as_str();
            let x_request_id = trace.x_request_id.as_deref().unwrap_or("");
            let x_dynamo_request_id = trace.x_dynamo_request_id.as_deref().unwrap_or("");
            let tracestate = trace.tracestate.as_deref().unwrap_or("");

            info_span!(
                "client_request",
                operation = operation,
                request_id = request_id,
                trace_id = trace_id,
                parent_id = parent_id,
                x_request_id = x_request_id,
                x_dynamo_request_id = x_dynamo_request_id,
                tracestate = tracestate
            )
        }
    }
}
/// Builds the `client_request` span for `operation` from a binding-level context.
///
/// The request id is the id of the wrapped engine context; trace fields come
/// from the context's stored trace context when it has one. A span is always
/// returned.
fn get_span_for_context(context: &context::Context, operation: &str) -> tracing::Span {
    let engine = context.inner();
    let trace = context.trace_context();
    create_client_span(operation, engine.id(), trace)
}
/// Builds the `client_request` span for a direct call to a specific instance.
///
/// Carries the same fields as the span for the other routing modes, plus
/// `instance_id`. The request id is the id of the wrapped engine context.
/// Trace fields are attached only when the context holds a trace context, and
/// an empty string is recorded for any optional trace value that is absent.
fn get_span_for_direct_context(
    context: &context::Context,
    operation: &str,
    instance_id: &str,
) -> tracing::Span {
    // Keep the engine handle alive for as long as its id is borrowed.
    let engine = context.inner();
    let request_id = engine.id();

    match context.trace_context() {
        None => info_span!(
            "client_request",
            operation = operation,
            request_id = request_id,
            instance_id = instance_id,
        ),
        Some(trace) => {
            let trace_id = trace.trace_id.as_str();
            let parent_id = trace.span_id.as_str();
            let x_request_id = trace.x_request_id.as_deref().unwrap_or("");
            let x_dynamo_request_id = trace.x_dynamo_request_id.as_deref().unwrap_or("");
            let tracestate = trace.tracestate.as_deref().unwrap_or("");

            info_span!(
                "client_request",
                operation = operation,
                request_id = request_id,
                instance_id = instance_id,
                trace_id = trace_id,
                parent_id = parent_id,
                x_request_id = x_request_id,
                x_dynamo_request_id = x_dynamo_request_id,
                tracestate = tracestate
            )
        }
    }
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
......@@ -964,8 +1024,13 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.round_robin(request).await.map_err(to_pyerr)?;
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
// Always instrument with a client span (trace fields are included only when a trace context is present)
let span = get_span_for_context(&context, "round_robin");
let stream_future = client.round_robin(request_ctx).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
......@@ -997,8 +1062,12 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.random(request).await.map_err(to_pyerr)?;
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
let span = get_span_for_context(&context, "random");
let stream_future = client.random(request_ctx).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
......@@ -1031,11 +1100,13 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client
.direct(request, instance_id)
.await
.map_err(to_pyerr)?;
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
let span =
get_span_for_direct_context(&context, "direct", &instance_id.to_string());
let stream_future = client.direct(request_ctx, instance_id).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
......@@ -1072,8 +1143,12 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.r#static(request).await.map_err(to_pyerr)?;
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
let span = get_span_for_context(&context, "static");
let stream_future = client.r#static(request_ctx).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment