"vscode:/vscode.git/clone" did not exist on "9ad38639ef380e7f147cfaac77a71c013224fb98"
Unverified Commit c263a99e authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: propagate OTEL trace context across E/P/D multimodal workers (#7239)

parent 34f13a12
......@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlWriteEmbeddingReceiver,
)
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client, DistributedRuntime
......@@ -156,7 +157,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# ── Multimodal data loading ──────────────────────────────────────
async def _load_multimodal_data(
self, image_urls: list[str], request_id: str
self, image_urls: list[str], request_id: str, context=None
) -> dict[str, Any]:
"""Fetch embeddings from encode workers and load into an engine-ready dict.
......@@ -174,6 +175,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE,
cache=self.embedding_cache_manager,
context=context,
)
# ── Request metadata finalization ────────────────────────────────
......@@ -260,9 +262,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
rng_ttft=None,
context=None,
):
"""Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model)
trace_headers = build_trace_headers(context) if context else None
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
......@@ -271,6 +275,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
sampling_params=request.sampling_params,
request_id=request.request_id,
lora_request=lora_request,
trace_headers=trace_headers,
)
num_output_tokens_so_far = 0
......@@ -302,6 +307,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
rng_ttft=None,
context=None,
):
"""Prefill locally, then forward to a remote decode worker."""
with _nvtx.annotate(
......@@ -319,6 +325,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
trace_headers = build_trace_headers(context) if context else None
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt[
......@@ -329,6 +336,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
sampling_params=prefill_only_request.sampling_params,
request_id=prefill_only_request.request_id,
lora_request=lora_request,
trace_headers=trace_headers,
)
# Drain prefill generator (max_tokens=1, expect a single response)
......@@ -382,7 +390,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
async for (
decode_response
) in await self.decode_worker_client.round_robin( # type: ignore
request.model_dump_json()
request.model_dump_json(), context=context
):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far)
......@@ -406,7 +414,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id
image_urls, request.request_id, context
)
_nvtx.end_range(rng_load)
......@@ -415,13 +423,15 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if self.enable_disagg and self.decode_worker_client:
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
async for chunk in self._generate_disagg(
request, multi_modal_data, rng_ttft
request, multi_modal_data, rng_ttft, context=context
):
yield chunk
_nvtx.end_range(rng_disagg)
else:
rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
async for chunk in self._generate_agg(request, multi_modal_data, rng_ttft):
async for chunk in self._generate_agg(
request, multi_modal_data, rng_ttft, context=context
):
yield chunk
_nvtx.end_range(rng_agg)
......
......@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt
import dynamo.nixl_connect as connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime
......@@ -57,14 +58,14 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def generate(self, request: vLLMMultimodalRequest, context):
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
with time_and_log_code_section(
f"[DECODE] request: {request.request_id} preprocessing time"
):
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
......@@ -90,6 +91,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
image_grid_thw, embeddings_shape, request.request_id
)
lora_request = self._resolve_lora_request(request.model)
trace_headers = build_trace_headers(context) if context else None
with time_and_log_code_section(
f"[DECODE] request: {request.request_id} generate time"
......@@ -102,6 +104,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
sampling_params=request.sampling_params,
request_id=request.request_id,
lora_request=lora_request,
trace_headers=trace_headers,
)
rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")
......
......@@ -140,6 +140,7 @@ async def _fetch_from_encode_workers(
image_urls: List[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
context=None,
) -> tuple[List[MultiModalGroup], _PendingRelease | None]:
"""Fan out image URLs to encode workers, load embeddings, and return ready groups.
......@@ -176,7 +177,7 @@ async def _fetch_from_encode_workers(
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
await encode_worker_client.round_robin(payload, context=context) # type: ignore[arg-type]
)
batch = []
......@@ -184,7 +185,7 @@ async def _fetch_from_encode_workers(
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
await encode_worker_client.round_robin(payload, context=context) # type: ignore[arg-type]
)
with time_and_log_code_section(
......@@ -223,6 +224,7 @@ async def _fetch_embeddings(
request_id: str,
receiver: AbstractEmbeddingReceiver,
cache: MultimodalEmbeddingCacheManager | None = None,
context=None,
) -> tuple[list[MultiModalGroup], _PendingRelease | None]:
"""Fetch multimodal embeddings with transparent cache-through.
......@@ -262,6 +264,7 @@ async def _fetch_embeddings(
miss_urls,
request_id,
receiver,
context=context,
)
# ── 3. Update cache (no-op when cache is None) ──────────────
......@@ -293,6 +296,7 @@ async def load_multimodal_embeddings(
model: str,
embeddings_dtype: torch.dtype,
cache: MultimodalEmbeddingCacheManager | None = None,
context=None,
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
......@@ -307,6 +311,7 @@ async def load_multimodal_embeddings(
request_id,
receiver,
cache=cache,
context=context,
)
multi_modal_data: Dict[str, Any] = defaultdict(list)
......
......@@ -299,7 +299,7 @@ class TestGenerateDisagg:
decode_resp = MagicMock()
decode_resp.data.return_value = decode_json
async def fake_round_robin(payload):
async def fake_round_robin(payload, context=None):
async def _stream():
yield decode_resp
......
......@@ -209,19 +209,35 @@ class Client:
"""
...
async def random(self, request: JsonLike) -> AsyncIterator[JsonLike]:
async def random(
self,
request: JsonLike,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
"""
Pick a random instance of the endpoint and issue the request
"""
...
async def round_robin(self, request: JsonLike) -> AsyncIterator[JsonLike]:
async def round_robin(
self,
request: JsonLike,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
"""
Pick the next instance of the endpoint in a round-robin fashion
"""
...
async def direct(self, request: JsonLike, instance: str) -> AsyncIterator[JsonLike]:
async def direct(
self,
request: JsonLike,
instance_id: int,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
"""
Pick a specific instance of the endpoint
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment