Unverified Commit cbbaa6b7 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(sglang): use external_trace_header API for distributed tracing (#5346)


Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent 1a8dcacd
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import base64
import json
import logging import logging
import random import random
import socket import socket
...@@ -12,7 +10,6 @@ from contextlib import asynccontextmanager ...@@ -12,7 +10,6 @@ from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Tuple from typing import Any, AsyncGenerator, Dict, Optional, Tuple
import sglang as sgl import sglang as sgl
from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Component, Context from dynamo._core import Component, Context
...@@ -143,38 +140,20 @@ class BaseWorkerHandler(ABC): ...@@ -143,38 +140,20 @@ class BaseWorkerHandler(ABC):
return bootstrap_host, bootstrap_port return bootstrap_host, bootstrap_port
def _propagate_trace_context_to_sglang( def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
self, context: Context, bootstrap_room: int = 0 """Get trace header dict for passing to SGLang's external_trace_header parameter.
):
"""Propagate Dynamo's trace context to SGLang for distributed tracing. SGLang expects a certain
format derived by loooking at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py
in the to_dict() method.
Args: Args:
context: Dynamo Context object containing trace information. context: Dynamo Context object containing trace information.
bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated).
Returns:
Dict with traceparent header if trace context available, None otherwise.
""" """
trace_id = context.trace_id trace_id = context.trace_id
span_id = context.span_id span_id = context.span_id
if not trace_id or not span_id: if not trace_id or not span_id:
return return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
# Build trace context for SGLang
trace_context = {
str(bootstrap_room): {
"root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"},
"prev_span": {
"span_id": int(span_id, 16),
"trace_id": int(trace_id, 16),
},
}
}
# Encode and propagate
base64_context = base64.b64encode(
json.dumps(trace_context, ensure_ascii=False).encode("utf-8")
).decode("utf-8")
sglang_trace.trace_set_remote_propagate_context(base64_context)
async def _handle_cancellation( async def _handle_cancellation(
self, request_id_future: asyncio.Future, context: Context self, request_id_future: asyncio.Future, context: Context
......
...@@ -119,9 +119,8 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -119,9 +119,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
f"room={bootstrap_info['bootstrap_room']}" f"room={bootstrap_info['bootstrap_room']}"
) )
if self.enable_trace: trace_header = (
self._propagate_trace_context_to_sglang( self._get_trace_header(context) if self.enable_trace else None
context, bootstrap_info["bootstrap_room"]
) )
decode = await self.engine.async_generate( decode = await self.engine.async_generate(
...@@ -131,6 +130,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -131,6 +130,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"], bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"], bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
) )
...@@ -141,13 +141,15 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -141,13 +141,15 @@ class DecodeWorkerHandler(BaseWorkerHandler):
async for out in self._process_text_stream(decode, context): async for out in self._process_text_stream(decode, context):
yield out yield out
else: else:
if self.enable_trace: trace_header = (
self._propagate_trace_context_to_sglang(context) self._get_trace_header(context) if self.enable_trace else None
)
agg = await self.engine.async_generate( agg = await self.engine.async_generate(
**input_param, **input_param,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
) )
if self.skip_tokenizer_init: if self.skip_tokenizer_init:
......
...@@ -113,9 +113,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -113,9 +113,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(inner_request) input_param = self._get_input_param(inner_request)
# Propagate trace context to SGLang trace_header = self._get_trace_header(context) if self.enable_trace else None
if self.enable_trace:
self._propagate_trace_context_to_sglang(context, bootstrap_room)
results = await self.engine.async_generate( results = await self.engine.async_generate(
**input_param, **input_param,
...@@ -124,6 +122,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -124,6 +122,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
bootstrap_host=self.bootstrap_host, bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
) )
......
...@@ -8,6 +8,7 @@ use futures::StreamExt; ...@@ -8,6 +8,7 @@ use futures::StreamExt;
use rand::Rng; use rand::Rng;
use tokio::sync::{OwnedSemaphorePermit, oneshot}; use tokio::sync::{OwnedSemaphorePermit, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use dynamo_runtime::{ use dynamo_runtime::{
component::Endpoint, component::Endpoint,
...@@ -265,10 +266,14 @@ impl PrefillRouter { ...@@ -265,10 +266,14 @@ impl PrefillRouter {
InnerPrefillRouter::KvRouter(r) => r, InnerPrefillRouter::KvRouter(r) => r,
_ => return None, _ => return None,
}; };
match kv_router match async {
kv_router
.chooser .chooser
.find_best_match(None, &req.token_ids, None, false) .find_best_match(None, &req.token_ids, None, false)
.await .await
}
.instrument(tracing::info_span!("kv_find_best_match"))
.await
{ {
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank), Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None, Err(_) => return None,
...@@ -405,9 +410,17 @@ impl PrefillRouter { ...@@ -405,9 +410,17 @@ impl PrefillRouter {
phase_permit: OwnedSemaphorePermit, phase_permit: OwnedSemaphorePermit,
) { ) {
let router = self.prefill_router.get().cloned(); let router = self.prefill_router.get().cloned();
// Capture current span to propagate trace context to the spawned task
tokio::spawn(async move { let span = tracing::Span::current();
match Self::execute_prefill(router, prefill_request, target_worker, Some(phase_permit))
tokio::spawn(
async move {
match Self::execute_prefill(
router,
prefill_request,
target_worker,
Some(phase_permit),
)
.await .await
{ {
Ok(_) => { Ok(_) => {
...@@ -417,7 +430,9 @@ impl PrefillRouter { ...@@ -417,7 +430,9 @@ impl PrefillRouter {
tracing::warn!("Prefill background task error: {e:?}"); tracing::warn!("Prefill background task error: {e:?}");
} }
} }
}); }
.instrument(span),
);
} }
/// Call the prefill router and extract structured prefill result and worker ID. /// Call the prefill router and extract structured prefill result and worker ID.
...@@ -491,7 +506,8 @@ impl ...@@ -491,7 +506,8 @@ impl
.as_ref() .as_ref()
.and_then(|r| r.prefill_worker_id); .and_then(|r| r.prefill_worker_id);
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = self let prefill_result = async {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker) .build_bootstrap_info(&prefill_req, preselected_worker)
.await .await
{ {
...@@ -531,7 +547,10 @@ impl ...@@ -531,7 +547,10 @@ impl
self.call_prefill(prefill_context) self.call_prefill(prefill_context)
.await .await
.map(|(result, worker_id)| (Some(result), worker_id, None)) .map(|(result, worker_id)| (Some(result), worker_id, None))
}; }
}
.instrument(tracing::info_span!("prefill_routing"))
.await;
// Abort if cancelled during prefill // Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() { if engine_ctx.is_stopped() || engine_ctx.is_killed() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment