Unverified Commit 8886d184 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: propagate OTEL trace context across SGLang multimodal E/P/D workers (#7580)

parent e5e65b58
......@@ -9,6 +9,7 @@ from typing import Any, Dict, Optional
import sglang as sgl
from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import EmbeddingRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -55,7 +56,14 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
else:
raise TypeError(f"Invalid input type: {type(embedding_request.input)}")
result = await self.engine.async_encode(prompt=prompt)
trace_header = build_trace_headers(context) if self.enable_trace else None
trace_id = context.trace_id
result = await self.engine.async_encode(
prompt=prompt,
external_trace_header=trace_header,
rid=trace_id,
)
# Transform the response to OpenAI format
response = self._transform_response(result, embedding_request.model)
......
......@@ -107,6 +107,7 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
publisher: Optional metrics publisher for the worker.
"""
self.config = config
self.enable_trace = config.server_args.enable_trace
# Set up metrics and KV publishers
self.metrics_publisher: Optional[WorkerMetricsPublisher] = None
......@@ -132,21 +133,6 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
"""Cleanup resources. Override in subclasses as needed."""
pass
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to generation functions.
Args:
context: Dynamo Context object containing trace information.
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
"""Abstract base class for SGLang LLM worker handlers.
......
......@@ -15,6 +15,7 @@ from PIL import Image
from dynamo._core import Context
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -84,7 +85,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
logger.info(f"Image diffusion request: {request}")
# Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context)
trace_header = build_trace_headers(context) if self.enable_trace else None
if trace_header:
logger.debug(f"Image diffusion request with trace: {trace_header}")
......
......@@ -12,6 +12,7 @@ import sglang as sgl
from dynamo._core import Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -129,9 +130,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
f"room={bootstrap_info['bootstrap_room']}"
)
trace_header = (
self._get_trace_header(context) if self.enable_trace else None
)
trace_header = build_trace_headers(context) if self.enable_trace else None
# Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {}
......@@ -171,9 +170,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
image_data.append(item["Url"])
image_data = image_data or None
trace_header = (
self._get_trace_header(context) if self.enable_trace else None
)
trace_header = build_trace_headers(context) if self.enable_trace else None
# Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {}
......
......@@ -7,6 +7,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler
......@@ -76,7 +77,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
sampling_params = self._build_sampling_params(request)
# Generate trace info if tracing is enabled
trace_header = self._get_trace_header(context) if self.enable_trace else None
trace_header = build_trace_headers(context) if self.enable_trace else None
trace_id = context.id() if trace_header else None
async_gen = await self.engine.async_generate(
......
......@@ -8,6 +8,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -140,7 +141,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
if dp_rank is not None and dp_rank == _DP_RANK_UNSET:
dp_rank = None
trace_header = self._get_trace_header(context) if self.enable_trace else None
trace_header = build_trace_headers(context) if self.enable_trace else None
results = await self.engine.async_generate(
**input_param,
......
......@@ -315,7 +315,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
# Get the response generator from downstream worker
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
request.model_dump_json(), context=context
)
# Parse PD worker responses and yield as LLMEngineOutput-
......
......@@ -14,6 +14,7 @@ from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode
from dynamo.common.multimodal import EMBEDDING_RECEIVER_FACTORIES, TransferRequest
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
DisaggSglangMultimodalRequest,
......@@ -328,7 +329,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
try:
async for output in self._generate_disaggregated(
request, _end_ttft
request, _end_ttft, context=context
):
yield output
finally:
......@@ -336,7 +337,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
else:
rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
try:
async for output in self._generate_aggregated(request, _end_ttft):
async for output in self._generate_aggregated(
request, _end_ttft, context=context
):
yield output
finally:
_nvtx.end_range(rng_agg)
......@@ -352,6 +355,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
self,
request: SglangMultimodalRequest,
end_ttft: Callable[[], None],
context=None,
) -> AsyncIterator[str]:
"""Handle disaggregated mode generation"""
input_ids = request.request.token_ids
......@@ -362,7 +366,11 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
# Request bootstrap info from prefill worker
bootstrap_info = await self._get_bootstrap_from_prefill(
request, sampling_params
request, sampling_params, context=context
)
trace_header = (
build_trace_headers(context) if context and self.enable_trace else None
)
# Start decode generation with bootstrap info (no image data needed)
......@@ -373,6 +381,8 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header,
rid=context.trace_id if context else None,
)
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
......@@ -393,6 +403,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
self,
request: SglangMultimodalRequest,
end_ttft: Callable[[], None],
context=None,
) -> AsyncIterator[str]:
"""Handle aggregated mode generation"""
input_ids = request.request.token_ids
......@@ -412,11 +423,17 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
)
logger.debug(f"Input token sequence length: {len(input_ids)}")
trace_header = (
build_trace_headers(context) if context and self.enable_trace else None
)
agg_stream = await self.engine.async_generate(
input_ids=input_ids,
image_data=mm_items,
sampling_params=sampling_params,
stream=True,
external_trace_header=trace_header,
rid=context.trace_id if context else None,
)
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
......@@ -459,7 +476,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
self.embeddings_processor.release_embeddings(tensor_id)
async def _get_bootstrap_from_prefill(
self, request: SglangMultimodalRequest, sampling_params: dict
self, request: SglangMultimodalRequest, sampling_params: dict, context=None
) -> dict:
"""Get bootstrap info from prefill worker"""
assert self.prefill_client is not None
......@@ -467,7 +484,8 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
DisaggSglangMultimodalRequest(
request=request,
sampling_params=sampling_params,
).model_dump_json()
).model_dump_json(),
context=context,
)
bootstrap_info = None
......@@ -554,7 +572,9 @@ class MultimodalPrefillWorkerHandler(
yield json.dumps(bootstrap_info)
# Process prefill generation
await self._process_prefill_generation(disagg_request, bootstrap_room)
await self._process_prefill_generation(
disagg_request, bootstrap_room, context=context
)
except Exception as e:
logger.error(f"Error in prefill generation: {e}", exc_info=True)
......@@ -581,7 +601,10 @@ class MultimodalPrefillWorkerHandler(
return disagg_request
async def _process_prefill_generation(
self, disagg_request: DisaggSglangMultimodalRequest, bootstrap_room: int
self,
disagg_request: DisaggSglangMultimodalRequest,
bootstrap_room: int,
context=None,
):
"""Process multimodal input and start prefill generation"""
# Get the SglangMultimodalRequest from the DisaggSglangMultimodalRequest
......@@ -596,6 +619,10 @@ class MultimodalPrefillWorkerHandler(
request, self.embeddings_processor
)
trace_header = (
build_trace_headers(context) if context and self.enable_trace else None
)
# Start SGLang prefill generation (like regular SGLang)
with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"):
results = await self.engine.async_generate(
......@@ -606,6 +633,8 @@ class MultimodalPrefillWorkerHandler(
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room,
external_trace_header=trace_header,
rid=context.trace_id if context else None,
)
# Consume results without yielding (prefill doesn't return text, just coordinates)
......
......@@ -13,6 +13,7 @@ import torch
from dynamo._core import Context
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
CreateVideoRequest,
......@@ -89,7 +90,7 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
start_time = time.time()
# Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context)
trace_header = build_trace_headers(context) if self.enable_trace else None
if trace_header:
logger.debug(f"Video generation request with trace: {trace_header}")
......
......@@ -348,9 +348,6 @@ class TestImageDiffusionWorkerHandler:
test_image = Image.new("RGB", (256, 256), color="yellow")
handler._generate_images = Mock(return_value=[test_image.tobytes()])
handler._get_trace_header = Mock(
return_value={"traceparent": "00-1234567890-1234567890-01"}
)
request = {
"prompt": "A yellow square",
......@@ -368,6 +365,11 @@ class TestImageDiffusionWorkerHandler:
# Execute generation
results = []
trace_patch = patch(
"dynamo.sglang.request_handlers.image_diffusion.image_diffusion_handler.build_trace_headers",
return_value={"traceparent": "00-1234567890-1234567890-01"},
)
with trace_patch:
async for result in handler.generate(request, mock_context):
results.append(result)
......
......@@ -137,7 +137,7 @@ Dynamo RPC (NATS transport)
|
v
SGLang Handler (Python)
handler_base.py:_get_trace_header(context)
dynamo.common.utils.otel_tracing.build_trace_headers(context)
builds W3C traceparent: "00-{trace_id}-{span_id}-01"
|
v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment