Unverified Commit 8886d184 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: propagate OTEL trace context across SGLang multimodal E/P/D workers (#7580)

parent e5e65b58
...@@ -9,6 +9,7 @@ from typing import Any, Dict, Optional ...@@ -9,6 +9,7 @@ from typing import Any, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import EmbeddingRequest from dynamo.sglang.protocol import EmbeddingRequest
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -55,7 +56,14 @@ class EmbeddingWorkerHandler(BaseWorkerHandler): ...@@ -55,7 +56,14 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
else: else:
raise TypeError(f"Invalid input type: {type(embedding_request.input)}") raise TypeError(f"Invalid input type: {type(embedding_request.input)}")
result = await self.engine.async_encode(prompt=prompt) trace_header = build_trace_headers(context) if self.enable_trace else None
trace_id = context.trace_id
result = await self.engine.async_encode(
prompt=prompt,
external_trace_header=trace_header,
rid=trace_id,
)
# Transform the response to OpenAI format # Transform the response to OpenAI format
response = self._transform_response(result, embedding_request.model) response = self._transform_response(result, embedding_request.model)
......
...@@ -107,6 +107,7 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -107,6 +107,7 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
publisher: Optional metrics publisher for the worker. publisher: Optional metrics publisher for the worker.
""" """
self.config = config self.config = config
self.enable_trace = config.server_args.enable_trace
# Set up metrics and KV publishers # Set up metrics and KV publishers
self.metrics_publisher: Optional[WorkerMetricsPublisher] = None self.metrics_publisher: Optional[WorkerMetricsPublisher] = None
...@@ -132,21 +133,6 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -132,21 +133,6 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
"""Cleanup resources. Override in subclasses as needed.""" """Cleanup resources. Override in subclasses as needed."""
pass pass
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to generation functions.
Args:
context: Dynamo Context object containing trace information.
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]): class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
"""Abstract base class for SGLang LLM worker handlers. """Abstract base class for SGLang LLM worker handlers.
......
...@@ -15,6 +15,7 @@ from PIL import Image ...@@ -15,6 +15,7 @@ from PIL import Image
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.storage import upload_to_fs from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -84,7 +85,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -84,7 +85,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
logger.info(f"Image diffusion request: {request}") logger.info(f"Image diffusion request: {request}")
# Get trace header for distributed tracing (for logging/observability) # Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context) trace_header = build_trace_headers(context) if self.enable_trace else None
if trace_header: if trace_header:
logger.debug(f"Image diffusion request with trace: {trace_header}") logger.debug(f"Image diffusion request with trace: {trace_header}")
......
...@@ -12,6 +12,7 @@ import sglang as sgl ...@@ -12,6 +12,7 @@ import sglang as sgl
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.constants import DisaggregationMode from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -129,9 +130,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -129,9 +130,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
f"room={bootstrap_info['bootstrap_room']}" f"room={bootstrap_info['bootstrap_room']}"
) )
trace_header = ( trace_header = build_trace_headers(context) if self.enable_trace else None
self._get_trace_header(context) if self.enable_trace else None
)
# Extract dp_rank from routing info (set by KV router) # Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {} routing = request.get("routing") or {}
...@@ -171,9 +170,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -171,9 +170,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
image_data.append(item["Url"]) image_data.append(item["Url"])
image_data = image_data or None image_data = image_data or None
trace_header = ( trace_header = build_trace_headers(context) if self.enable_trace else None
self._get_trace_header(context) if self.enable_trace else None
)
# Extract dp_rank from routing info (set by KV router) # Extract dp_rank from routing info (set by KV router)
routing = request.get("routing") or {} routing = request.get("routing") or {}
......
...@@ -7,6 +7,7 @@ from typing import Any, AsyncGenerator, Dict, Optional ...@@ -7,6 +7,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler
...@@ -76,7 +77,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler): ...@@ -76,7 +77,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
sampling_params = self._build_sampling_params(request) sampling_params = self._build_sampling_params(request)
# Generate trace info if tracing is enabled # Generate trace info if tracing is enabled
trace_header = self._get_trace_header(context) if self.enable_trace else None trace_header = build_trace_headers(context) if self.enable_trace else None
trace_id = context.id() if trace_header else None trace_id = context.id() if trace_header else None
async_gen = await self.engine.async_generate( async_gen = await self.engine.async_generate(
......
...@@ -8,6 +8,7 @@ from typing import Any, AsyncGenerator, Dict, Optional ...@@ -8,6 +8,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -140,7 +141,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -140,7 +141,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
if dp_rank is not None and dp_rank == _DP_RANK_UNSET: if dp_rank is not None and dp_rank == _DP_RANK_UNSET:
dp_rank = None dp_rank = None
trace_header = self._get_trace_header(context) if self.enable_trace else None trace_header = build_trace_headers(context) if self.enable_trace else None
results = await self.engine.async_generate( results = await self.engine.async_generate(
**input_param, **input_param,
......
...@@ -315,7 +315,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s ...@@ -315,7 +315,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
# Get the response generator from downstream worker # Get the response generator from downstream worker
response_generator = await self.pd_worker_client.round_robin( response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json() request.model_dump_json(), context=context
) )
# Parse PD worker responses and yield as LLMEngineOutput- # Parse PD worker responses and yield as LLMEngineOutput-
......
...@@ -14,6 +14,7 @@ from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode ...@@ -14,6 +14,7 @@ from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode
from dynamo.common.multimodal import EMBEDDING_RECEIVER_FACTORIES, TransferRequest from dynamo.common.multimodal import EMBEDDING_RECEIVER_FACTORIES, TransferRequest
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
DisaggSglangMultimodalRequest, DisaggSglangMultimodalRequest,
...@@ -328,7 +329,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -328,7 +329,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red") rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
try: try:
async for output in self._generate_disaggregated( async for output in self._generate_disaggregated(
request, _end_ttft request, _end_ttft, context=context
): ):
yield output yield output
finally: finally:
...@@ -336,7 +337,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -336,7 +337,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
else: else:
rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red") rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
try: try:
async for output in self._generate_aggregated(request, _end_ttft): async for output in self._generate_aggregated(
request, _end_ttft, context=context
):
yield output yield output
finally: finally:
_nvtx.end_range(rng_agg) _nvtx.end_range(rng_agg)
...@@ -352,6 +355,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -352,6 +355,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
self, self,
request: SglangMultimodalRequest, request: SglangMultimodalRequest,
end_ttft: Callable[[], None], end_ttft: Callable[[], None],
context=None,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Handle disaggregated mode generation""" """Handle disaggregated mode generation"""
input_ids = request.request.token_ids input_ids = request.request.token_ids
...@@ -362,7 +366,11 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -362,7 +366,11 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
# Request bootstrap info from prefill worker # Request bootstrap info from prefill worker
bootstrap_info = await self._get_bootstrap_from_prefill( bootstrap_info = await self._get_bootstrap_from_prefill(
request, sampling_params request, sampling_params, context=context
)
trace_header = (
build_trace_headers(context) if context and self.enable_trace else None
) )
# Start decode generation with bootstrap info (no image data needed) # Start decode generation with bootstrap info (no image data needed)
...@@ -373,6 +381,8 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -373,6 +381,8 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"], bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"], bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header,
rid=context.trace_id if context else None,
) )
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple") rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
...@@ -393,6 +403,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -393,6 +403,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
self, self,
request: SglangMultimodalRequest, request: SglangMultimodalRequest,
end_ttft: Callable[[], None], end_ttft: Callable[[], None],
context=None,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Handle aggregated mode generation""" """Handle aggregated mode generation"""
input_ids = request.request.token_ids input_ids = request.request.token_ids
...@@ -412,11 +423,17 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -412,11 +423,17 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
) )
logger.debug(f"Input token sequence length: {len(input_ids)}") logger.debug(f"Input token sequence length: {len(input_ids)}")
trace_header = (
build_trace_headers(context) if context and self.enable_trace else None
)
agg_stream = await self.engine.async_generate( agg_stream = await self.engine.async_generate(
input_ids=input_ids, input_ids=input_ids,
image_data=mm_items, image_data=mm_items,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
external_trace_header=trace_header,
rid=context.trace_id if context else None,
) )
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple") rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
...@@ -459,7 +476,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -459,7 +476,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
self.embeddings_processor.release_embeddings(tensor_id) self.embeddings_processor.release_embeddings(tensor_id)
async def _get_bootstrap_from_prefill( async def _get_bootstrap_from_prefill(
self, request: SglangMultimodalRequest, sampling_params: dict self, request: SglangMultimodalRequest, sampling_params: dict, context=None
) -> dict: ) -> dict:
"""Get bootstrap info from prefill worker""" """Get bootstrap info from prefill worker"""
assert self.prefill_client is not None assert self.prefill_client is not None
...@@ -467,7 +484,8 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]): ...@@ -467,7 +484,8 @@ class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
DisaggSglangMultimodalRequest( DisaggSglangMultimodalRequest(
request=request, request=request,
sampling_params=sampling_params, sampling_params=sampling_params,
).model_dump_json() ).model_dump_json(),
context=context,
) )
bootstrap_info = None bootstrap_info = None
...@@ -554,7 +572,9 @@ class MultimodalPrefillWorkerHandler( ...@@ -554,7 +572,9 @@ class MultimodalPrefillWorkerHandler(
yield json.dumps(bootstrap_info) yield json.dumps(bootstrap_info)
# Process prefill generation # Process prefill generation
await self._process_prefill_generation(disagg_request, bootstrap_room) await self._process_prefill_generation(
disagg_request, bootstrap_room, context=context
)
except Exception as e: except Exception as e:
logger.error(f"Error in prefill generation: {e}", exc_info=True) logger.error(f"Error in prefill generation: {e}", exc_info=True)
...@@ -581,7 +601,10 @@ class MultimodalPrefillWorkerHandler( ...@@ -581,7 +601,10 @@ class MultimodalPrefillWorkerHandler(
return disagg_request return disagg_request
async def _process_prefill_generation( async def _process_prefill_generation(
self, disagg_request: DisaggSglangMultimodalRequest, bootstrap_room: int self,
disagg_request: DisaggSglangMultimodalRequest,
bootstrap_room: int,
context=None,
): ):
"""Process multimodal input and start prefill generation""" """Process multimodal input and start prefill generation"""
# Get the SglangMultimodalRequest from the DisaggSglangMultimodalRequest # Get the SglangMultimodalRequest from the DisaggSglangMultimodalRequest
...@@ -596,6 +619,10 @@ class MultimodalPrefillWorkerHandler( ...@@ -596,6 +619,10 @@ class MultimodalPrefillWorkerHandler(
request, self.embeddings_processor request, self.embeddings_processor
) )
trace_header = (
build_trace_headers(context) if context and self.enable_trace else None
)
# Start SGLang prefill generation (like regular SGLang) # Start SGLang prefill generation (like regular SGLang)
with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"): with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"):
results = await self.engine.async_generate( results = await self.engine.async_generate(
...@@ -606,6 +633,8 @@ class MultimodalPrefillWorkerHandler( ...@@ -606,6 +633,8 @@ class MultimodalPrefillWorkerHandler(
bootstrap_host=self.bootstrap_host, bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
external_trace_header=trace_header,
rid=context.trace_id if context else None,
) )
# Consume results without yielding (prefill doesn't return text, just coordinates) # Consume results without yielding (prefill doesn't return text, just coordinates)
......
...@@ -13,6 +13,7 @@ import torch ...@@ -13,6 +13,7 @@ import torch
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.storage import upload_to_fs from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
CreateVideoRequest, CreateVideoRequest,
...@@ -89,7 +90,7 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -89,7 +90,7 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
start_time = time.time() start_time = time.time()
# Get trace header for distributed tracing (for logging/observability) # Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context) trace_header = build_trace_headers(context) if self.enable_trace else None
if trace_header: if trace_header:
logger.debug(f"Video generation request with trace: {trace_header}") logger.debug(f"Video generation request with trace: {trace_header}")
......
...@@ -348,9 +348,6 @@ class TestImageDiffusionWorkerHandler: ...@@ -348,9 +348,6 @@ class TestImageDiffusionWorkerHandler:
test_image = Image.new("RGB", (256, 256), color="yellow") test_image = Image.new("RGB", (256, 256), color="yellow")
handler._generate_images = Mock(return_value=[test_image.tobytes()]) handler._generate_images = Mock(return_value=[test_image.tobytes()])
handler._get_trace_header = Mock(
return_value={"traceparent": "00-1234567890-1234567890-01"}
)
request = { request = {
"prompt": "A yellow square", "prompt": "A yellow square",
...@@ -368,6 +365,11 @@ class TestImageDiffusionWorkerHandler: ...@@ -368,6 +365,11 @@ class TestImageDiffusionWorkerHandler:
# Execute generation # Execute generation
results = [] results = []
trace_patch = patch(
"dynamo.sglang.request_handlers.image_diffusion.image_diffusion_handler.build_trace_headers",
return_value={"traceparent": "00-1234567890-1234567890-01"},
)
with trace_patch:
async for result in handler.generate(request, mock_context): async for result in handler.generate(request, mock_context):
results.append(result) results.append(result)
......
...@@ -137,7 +137,7 @@ Dynamo RPC (NATS transport) ...@@ -137,7 +137,7 @@ Dynamo RPC (NATS transport)
| |
v v
SGLang Handler (Python) SGLang Handler (Python)
handler_base.py:_get_trace_header(context) dynamo.common.utils.otel_tracing.build_trace_headers(context)
builds W3C traceparent: "00-{trace_id}-{span_id}-01" builds W3C traceparent: "00-{trace_id}-{span_id}-01"
| |
v v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment