Unverified commit 34ce777d, authored by Jie Hao and committed by GitHub
Browse files

feat: propagate otel tracing to trtllm E/P/D workers (#7592)

parent b78ec99a
......@@ -29,6 +29,7 @@ async def fetch_embeddings_from_encoder(
request: Dict[str, Any],
encode_client: Any,
encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
trace_context=None,
) -> Union[List[torch.Tensor], DisaggregatedParams]:
"""
Fetch embeddings from remote encode worker.
......@@ -38,6 +39,7 @@ async def fetch_embeddings_from_encoder(
request: Request dict (used for creating modified requests for caching)
encode_client: Client to call remote encode worker
encoder_cache: Optional cache for embeddings
trace_context: Optional Dynamo context for OTel trace propagation
Returns:
- List[torch.Tensor]: When using cache (CPU tensors from cache)
......@@ -56,13 +58,19 @@ async def fetch_embeddings_from_encoder(
request,
encoder_cache,
lambda req: _remote_encode_full_epd(
req, encode_client, update_request_for_decode=False
req,
encode_client,
update_request_for_decode=False,
trace_context=trace_context,
),
)
else:
# No cache: return DisaggregatedParams directly (no GPU→CPU extraction)
return await _remote_encode_full_epd(
request, encode_client, update_request_for_decode=True
request,
encode_client,
update_request_for_decode=True,
trace_context=trace_context,
)
......@@ -70,6 +78,7 @@ async def _remote_encode_full_epd(
request: Dict[str, Any],
encode_client: Any,
update_request_for_decode: bool = True,
trace_context=None,
) -> DisaggregatedParams:
"""
Call encode worker for full EPD flow.
......@@ -78,6 +87,7 @@ async def _remote_encode_full_epd(
request: Request dict
encode_client: Client to call remote encode worker
update_request_for_decode: If True, store EPD metadata in request
trace_context: Optional Dynamo context for OTel trace propagation
Returns:
DisaggregatedParams with multimodal_embedding_handles
......@@ -86,7 +96,7 @@ async def _remote_encode_full_epd(
RuntimeError: If encode worker returns invalid response
"""
encode_response = None
async for res in await encode_client.round_robin(request):
async for res in await encode_client.round_robin(request, context=trace_context):
encode_response = res.data()
break
......
......@@ -57,6 +57,7 @@ class AggregatedHandler(HandlerBase):
request,
self.encode_client,
self._encoder_cache,
trace_context=context,
)
if isinstance(result, list):
embeddings = result # type: ignore[assignment]
......
......@@ -100,12 +100,13 @@ class PrefillHandler(HandlerBase):
super().__init__(config)
self._encoder_cache = encoder_cache
async def remote_encode_with_nixl(self, request: dict):
async def remote_encode_with_nixl(self, request: dict, context=None):
"""
Call encode worker for NIXL flow to load embeddings and unpack the response.
Args:
request: Request dict
context: Optional Dynamo context for trace propagation
Returns:
Encoder's embeddings tensor to be used by the prefill worker
......@@ -114,7 +115,7 @@ class PrefillHandler(HandlerBase):
if self.encode_client is None:
raise RuntimeError("Encode client is not configured.")
encode_response = None
async for res in await self.encode_client.round_robin(request):
async for res in await self.encode_client.round_robin(request, context=context):
encode_response = res.data()
break
......@@ -154,7 +155,9 @@ class PrefillHandler(HandlerBase):
if embedding_paths:
if self.encode_client and self.connector:
logging.info(f"PrefillHandler: embedding_paths={embedding_paths}")
embeddings_tensor = await self.remote_encode_with_nixl(request)
embeddings_tensor = await self.remote_encode_with_nixl(
request, context=context
)
else:
# We can still handle embedding_paths without NIXL:
# `MultimodalRequestProcessor.process_openai_request` will load the embeddings
......@@ -172,6 +175,7 @@ class PrefillHandler(HandlerBase):
request,
self.encode_client,
self._encoder_cache,
trace_context=context,
)
if isinstance(result, list):
# Cache path: got List[torch.Tensor]
......
......@@ -51,7 +51,7 @@ def create_mock_encode_client(
"prompt_token_ids": prompt_token_ids or [1, 2, 3],
}
async def mock_round_robin(req: dict[str, Any]) -> Any:
async def mock_round_robin(req: dict[str, Any], context=None) -> Any:
async def gen():
yield MockResponse()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment