Commit 34ce777d (unverified), authored by Jie Hao, committed by GitHub
Browse files

feat: propagate otel tracing to trtllm E/P/D workers (#7592)

parent b78ec99a
...@@ -29,6 +29,7 @@ async def fetch_embeddings_from_encoder( ...@@ -29,6 +29,7 @@ async def fetch_embeddings_from_encoder(
request: Dict[str, Any], request: Dict[str, Any],
encode_client: Any, encode_client: Any,
encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None, encoder_cache: Optional[MultimodalEmbeddingCacheManager] = None,
trace_context=None,
) -> Union[List[torch.Tensor], DisaggregatedParams]: ) -> Union[List[torch.Tensor], DisaggregatedParams]:
""" """
Fetch embeddings from remote encode worker. Fetch embeddings from remote encode worker.
...@@ -38,6 +39,7 @@ async def fetch_embeddings_from_encoder( ...@@ -38,6 +39,7 @@ async def fetch_embeddings_from_encoder(
request: Request dict (used for creating modified requests for caching) request: Request dict (used for creating modified requests for caching)
encode_client: Client to call remote encode worker encode_client: Client to call remote encode worker
encoder_cache: Optional cache for embeddings encoder_cache: Optional cache for embeddings
trace_context: Optional Dynamo context for OTel trace propagation
Returns: Returns:
- List[torch.Tensor]: When using cache (CPU tensors from cache) - List[torch.Tensor]: When using cache (CPU tensors from cache)
...@@ -56,13 +58,19 @@ async def fetch_embeddings_from_encoder( ...@@ -56,13 +58,19 @@ async def fetch_embeddings_from_encoder(
request, request,
encoder_cache, encoder_cache,
lambda req: _remote_encode_full_epd( lambda req: _remote_encode_full_epd(
req, encode_client, update_request_for_decode=False req,
encode_client,
update_request_for_decode=False,
trace_context=trace_context,
), ),
) )
else: else:
# No cache: return DisaggregatedParams directly (no GPU→CPU extraction) # No cache: return DisaggregatedParams directly (no GPU→CPU extraction)
return await _remote_encode_full_epd( return await _remote_encode_full_epd(
request, encode_client, update_request_for_decode=True request,
encode_client,
update_request_for_decode=True,
trace_context=trace_context,
) )
...@@ -70,6 +78,7 @@ async def _remote_encode_full_epd( ...@@ -70,6 +78,7 @@ async def _remote_encode_full_epd(
request: Dict[str, Any], request: Dict[str, Any],
encode_client: Any, encode_client: Any,
update_request_for_decode: bool = True, update_request_for_decode: bool = True,
trace_context=None,
) -> DisaggregatedParams: ) -> DisaggregatedParams:
""" """
Call encode worker for full EPD flow. Call encode worker for full EPD flow.
...@@ -78,6 +87,7 @@ async def _remote_encode_full_epd( ...@@ -78,6 +87,7 @@ async def _remote_encode_full_epd(
request: Request dict request: Request dict
encode_client: Client to call remote encode worker encode_client: Client to call remote encode worker
update_request_for_decode: If True, store EPD metadata in request update_request_for_decode: If True, store EPD metadata in request
trace_context: Optional Dynamo context for OTel trace propagation
Returns: Returns:
DisaggregatedParams with multimodal_embedding_handles DisaggregatedParams with multimodal_embedding_handles
...@@ -86,7 +96,7 @@ async def _remote_encode_full_epd( ...@@ -86,7 +96,7 @@ async def _remote_encode_full_epd(
RuntimeError: If encode worker returns invalid response RuntimeError: If encode worker returns invalid response
""" """
encode_response = None encode_response = None
async for res in await encode_client.round_robin(request): async for res in await encode_client.round_robin(request, context=trace_context):
encode_response = res.data() encode_response = res.data()
break break
......
...@@ -57,6 +57,7 @@ class AggregatedHandler(HandlerBase): ...@@ -57,6 +57,7 @@ class AggregatedHandler(HandlerBase):
request, request,
self.encode_client, self.encode_client,
self._encoder_cache, self._encoder_cache,
trace_context=context,
) )
if isinstance(result, list): if isinstance(result, list):
embeddings = result # type: ignore[assignment] embeddings = result # type: ignore[assignment]
......
...@@ -100,12 +100,13 @@ class PrefillHandler(HandlerBase): ...@@ -100,12 +100,13 @@ class PrefillHandler(HandlerBase):
super().__init__(config) super().__init__(config)
self._encoder_cache = encoder_cache self._encoder_cache = encoder_cache
async def remote_encode_with_nixl(self, request: dict): async def remote_encode_with_nixl(self, request: dict, context=None):
""" """
Call encode worker for NIXL flow to load embeddings and unpack the response. Call encode worker for NIXL flow to load embeddings and unpack the response.
Args: Args:
request: Request dict request: Request dict
context: Optional Dynamo context for trace propagation
Returns: Returns:
Encoder's embeddings tensor to be used by the prefill worker Encoder's embeddings tensor to be used by the prefill worker
...@@ -114,7 +115,7 @@ class PrefillHandler(HandlerBase): ...@@ -114,7 +115,7 @@ class PrefillHandler(HandlerBase):
if self.encode_client is None: if self.encode_client is None:
raise RuntimeError("Encode client is not configured.") raise RuntimeError("Encode client is not configured.")
encode_response = None encode_response = None
async for res in await self.encode_client.round_robin(request): async for res in await self.encode_client.round_robin(request, context=context):
encode_response = res.data() encode_response = res.data()
break break
...@@ -154,7 +155,9 @@ class PrefillHandler(HandlerBase): ...@@ -154,7 +155,9 @@ class PrefillHandler(HandlerBase):
if embedding_paths: if embedding_paths:
if self.encode_client and self.connector: if self.encode_client and self.connector:
logging.info(f"PrefillHandler: embedding_paths={embedding_paths}") logging.info(f"PrefillHandler: embedding_paths={embedding_paths}")
embeddings_tensor = await self.remote_encode_with_nixl(request) embeddings_tensor = await self.remote_encode_with_nixl(
request, context=context
)
else: else:
# We can still handle embedding_paths without NIXL: # We can still handle embedding_paths without NIXL:
# `MultimodalRequestProcessor.process_openai_request` will load the embeddings # `MultimodalRequestProcessor.process_openai_request` will load the embeddings
...@@ -172,6 +175,7 @@ class PrefillHandler(HandlerBase): ...@@ -172,6 +175,7 @@ class PrefillHandler(HandlerBase):
request, request,
self.encode_client, self.encode_client,
self._encoder_cache, self._encoder_cache,
trace_context=context,
) )
if isinstance(result, list): if isinstance(result, list):
# Cache path: got List[torch.Tensor] # Cache path: got List[torch.Tensor]
......
...@@ -51,7 +51,7 @@ def create_mock_encode_client( ...@@ -51,7 +51,7 @@ def create_mock_encode_client(
"prompt_token_ids": prompt_token_ids or [1, 2, 3], "prompt_token_ids": prompt_token_ids or [1, 2, 3],
} }
async def mock_round_robin(req: dict[str, Any]) -> Any: async def mock_round_robin(req: dict[str, Any], context=None) -> Any:
async def gen(): async def gen():
yield MockResponse() yield MockResponse()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment