Unverified Commit cbbaa6b7 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(sglang): use external_trace_header API for distributed tracing (#5346)


Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent 1a8dcacd
......@@ -2,8 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import json
import logging
import random
import socket
......@@ -12,7 +10,6 @@ from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
import sglang as sgl
from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Component, Context
......@@ -143,38 +140,20 @@ class BaseWorkerHandler(ABC):
return bootstrap_host, bootstrap_port
def _propagate_trace_context_to_sglang(
self, context: Context, bootstrap_room: int = 0
):
"""Propagate Dynamo's trace context to SGLang for distributed tracing. SGLang expects a certain
format derived by loooking at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py
in the to_dict() method.
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to SGLang's external_trace_header parameter.
Args:
context: Dynamo Context object containing trace information.
bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated).
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return
# Build trace context for SGLang
trace_context = {
str(bootstrap_room): {
"root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"},
"prev_span": {
"span_id": int(span_id, 16),
"trace_id": int(trace_id, 16),
},
}
}
# Encode and propagate
base64_context = base64.b64encode(
json.dumps(trace_context, ensure_ascii=False).encode("utf-8")
).decode("utf-8")
sglang_trace.trace_set_remote_propagate_context(base64_context)
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
async def _handle_cancellation(
self, request_id_future: asyncio.Future, context: Context
......
......@@ -119,10 +119,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
f"room={bootstrap_info['bootstrap_room']}"
)
if self.enable_trace:
self._propagate_trace_context_to_sglang(
context, bootstrap_info["bootstrap_room"]
)
trace_header = (
self._get_trace_header(context) if self.enable_trace else None
)
decode = await self.engine.async_generate(
**input_param,
......@@ -131,6 +130,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header,
rid=trace_id,
)
......@@ -141,13 +141,15 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async for out in self._process_text_stream(decode, context):
yield out
else:
if self.enable_trace:
self._propagate_trace_context_to_sglang(context)
trace_header = (
self._get_trace_header(context) if self.enable_trace else None
)
agg = await self.engine.async_generate(
**input_param,
sampling_params=sampling_params,
stream=True,
external_trace_header=trace_header,
rid=trace_id,
)
if self.skip_tokenizer_init:
......
......@@ -113,9 +113,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(inner_request)
# Propagate trace context to SGLang
if self.enable_trace:
self._propagate_trace_context_to_sglang(context, bootstrap_room)
trace_header = self._get_trace_header(context) if self.enable_trace else None
results = await self.engine.async_generate(
**input_param,
......@@ -124,6 +122,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room,
external_trace_header=trace_header,
rid=trace_id,
)
......
......@@ -8,6 +8,7 @@ use futures::StreamExt;
use rand::Rng;
use tokio::sync::{OwnedSemaphorePermit, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use dynamo_runtime::{
component::Endpoint,
......@@ -265,10 +266,14 @@ impl PrefillRouter {
InnerPrefillRouter::KvRouter(r) => r,
_ => return None,
};
match kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
match async {
kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
}
.instrument(tracing::info_span!("kv_find_best_match"))
.await
{
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None,
......@@ -405,19 +410,29 @@ impl PrefillRouter {
phase_permit: OwnedSemaphorePermit,
) {
let router = self.prefill_router.get().cloned();
tokio::spawn(async move {
match Self::execute_prefill(router, prefill_request, target_worker, Some(phase_permit))
// Capture current span to propagate trace context to the spawned task
let span = tracing::Span::current();
tokio::spawn(
async move {
match Self::execute_prefill(
router,
prefill_request,
target_worker,
Some(phase_permit),
)
.await
{
Ok(_) => {
tracing::debug!("Prefill background task completed");
}
Err(e) => {
tracing::warn!("Prefill background task error: {e:?}");
{
Ok(_) => {
tracing::debug!("Prefill background task completed");
}
Err(e) => {
tracing::warn!("Prefill background task error: {e:?}");
}
}
}
});
.instrument(span),
);
}
/// Call the prefill router and extract structured prefill result and worker ID.
......@@ -491,47 +506,51 @@ impl
.as_ref()
.and_then(|r| r.prefill_worker_id);
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if !self.router_mode.is_kv_routing()
&& let Some(router) = self.prefill_router.get()
let prefill_result = async {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
router.select_next_worker();
}
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if !self.router_mode.is_kv_routing()
&& let Some(router) = self.prefill_router.get()
{
router.select_next_worker();
}
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
// Pass phase permit to spawned task - it drops after first output (record_worker complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
// Pass phase permit to spawned task - it drops after first output (record_worker complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");
// Drop the phase permit before calling call_prefill - we wait for completion
// so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);
// Drop the phase permit before calling call_prefill - we wait for completion
// so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
};
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
}
.instrument(tracing::info_span!("prefill_routing"))
.await;
// Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment