"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "7860861fbd0c2678e4118d68a26552f27adfc1a6"
Unverified Commit c263a99e authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: propagate OTEL trace context across E/P/D multimodal workers (#7239)

parent 34f13a12
...@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlWriteEmbeddingReceiver, NixlWriteEmbeddingReceiver,
) )
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
...@@ -156,7 +157,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -156,7 +157,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# ── Multimodal data loading ────────────────────────────────────── # ── Multimodal data loading ──────────────────────────────────────
async def _load_multimodal_data( async def _load_multimodal_data(
self, image_urls: list[str], request_id: str self, image_urls: list[str], request_id: str, context=None
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Fetch embeddings from encode workers and load into an engine-ready dict. """Fetch embeddings from encode workers and load into an engine-ready dict.
...@@ -174,6 +175,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -174,6 +175,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
model=self.config.model, model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE, embeddings_dtype=self.EMBEDDINGS_DTYPE,
cache=self.embedding_cache_manager, cache=self.embedding_cache_manager,
context=context,
) )
# ── Request metadata finalization ──────────────────────────────── # ── Request metadata finalization ────────────────────────────────
...@@ -260,9 +262,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -260,9 +262,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
rng_ttft=None, rng_ttft=None,
context=None,
): ):
"""Run prefill and decode on this worker (aggregated mode).""" """Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model) lora_request = self._resolve_lora_request(request.model)
trace_headers = build_trace_headers(context) if context else None
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"], prompt_token_ids=request.engine_prompt["prompt_token_ids"],
...@@ -271,6 +275,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -271,6 +275,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
request_id=request.request_id, request_id=request.request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
) )
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
...@@ -302,6 +307,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -302,6 +307,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
rng_ttft=None, rng_ttft=None,
context=None,
): ):
"""Prefill locally, then forward to a remote decode worker.""" """Prefill locally, then forward to a remote decode worker."""
with _nvtx.annotate( with _nvtx.annotate(
...@@ -319,6 +325,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -319,6 +325,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug("Prefill request: %s", prefill_only_request) logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model) lora_request = self._resolve_lora_request(request.model)
trace_headers = build_trace_headers(context) if context else None
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt[ prompt_token_ids=prefill_only_request.engine_prompt[
...@@ -329,6 +336,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -329,6 +336,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
sampling_params=prefill_only_request.sampling_params, sampling_params=prefill_only_request.sampling_params,
request_id=prefill_only_request.request_id, request_id=prefill_only_request.request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
) )
# Drain prefill generator (max_tokens=1, expect a single response) # Drain prefill generator (max_tokens=1, expect a single response)
...@@ -382,7 +390,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -382,7 +390,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
async for ( async for (
decode_response decode_response
) in await self.decode_worker_client.round_robin( # type: ignore ) in await self.decode_worker_client.round_robin( # type: ignore
request.model_dump_json() request.model_dump_json(), context=context
): ):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far) yield self._format_engine_output(output, num_output_tokens_so_far)
...@@ -406,7 +414,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -406,7 +414,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow") rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
multi_modal_data = await self._load_multimodal_data( multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id image_urls, request.request_id, context
) )
_nvtx.end_range(rng_load) _nvtx.end_range(rng_load)
...@@ -415,13 +423,15 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -415,13 +423,15 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if self.enable_disagg and self.decode_worker_client: if self.enable_disagg and self.decode_worker_client:
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red") rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
async for chunk in self._generate_disagg( async for chunk in self._generate_disagg(
request, multi_modal_data, rng_ttft request, multi_modal_data, rng_ttft, context=context
): ):
yield chunk yield chunk
_nvtx.end_range(rng_disagg) _nvtx.end_range(rng_disagg)
else: else:
rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red") rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
async for chunk in self._generate_agg(request, multi_modal_data, rng_ttft): async for chunk in self._generate_agg(
request, multi_modal_data, rng_ttft, context=context
):
yield chunk yield chunk
_nvtx.end_range(rng_agg) _nvtx.end_range(rng_agg)
......
...@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt ...@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -57,14 +58,14 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -57,14 +58,14 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue") rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}") logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
with time_and_log_code_section( with time_and_log_code_section(
f"[DECODE] request: {request.request_id} preprocessing time" f"[DECODE] request: {request.request_id} preprocessing time"
): ):
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.") logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# For Qwen VL models with mRoPE, we need to pass multi_modal_data containing # For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
...@@ -90,6 +91,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -90,6 +91,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
image_grid_thw, embeddings_shape, request.request_id image_grid_thw, embeddings_shape, request.request_id
) )
lora_request = self._resolve_lora_request(request.model) lora_request = self._resolve_lora_request(request.model)
trace_headers = build_trace_headers(context) if context else None
with time_and_log_code_section( with time_and_log_code_section(
f"[DECODE] request: {request.request_id} generate time" f"[DECODE] request: {request.request_id} generate time"
...@@ -102,6 +104,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -102,6 +104,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
request_id=request.request_id, request_id=request.request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
) )
rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred") rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")
......
...@@ -140,6 +140,7 @@ async def _fetch_from_encode_workers( ...@@ -140,6 +140,7 @@ async def _fetch_from_encode_workers(
image_urls: List[str], image_urls: List[str],
request_id: str, request_id: str,
receiver: AbstractEmbeddingReceiver, receiver: AbstractEmbeddingReceiver,
context=None,
) -> tuple[List[MultiModalGroup], _PendingRelease | None]: ) -> tuple[List[MultiModalGroup], _PendingRelease | None]:
"""Fan out image URLs to encode workers, load embeddings, and return ready groups. """Fan out image URLs to encode workers, load embeddings, and return ready groups.
...@@ -176,7 +177,7 @@ async def _fetch_from_encode_workers( ...@@ -176,7 +177,7 @@ async def _fetch_from_encode_workers(
encode_request.multimodal_inputs = batch encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json() payload = encode_request.model_dump_json()
encode_response_streams.append( encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type] await encode_worker_client.round_robin(payload, context=context) # type: ignore[arg-type]
) )
batch = [] batch = []
...@@ -184,7 +185,7 @@ async def _fetch_from_encode_workers( ...@@ -184,7 +185,7 @@ async def _fetch_from_encode_workers(
encode_request.multimodal_inputs = batch encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json() payload = encode_request.model_dump_json()
encode_response_streams.append( encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type] await encode_worker_client.round_robin(payload, context=context) # type: ignore[arg-type]
) )
with time_and_log_code_section( with time_and_log_code_section(
...@@ -223,6 +224,7 @@ async def _fetch_embeddings( ...@@ -223,6 +224,7 @@ async def _fetch_embeddings(
request_id: str, request_id: str,
receiver: AbstractEmbeddingReceiver, receiver: AbstractEmbeddingReceiver,
cache: MultimodalEmbeddingCacheManager | None = None, cache: MultimodalEmbeddingCacheManager | None = None,
context=None,
) -> tuple[list[MultiModalGroup], _PendingRelease | None]: ) -> tuple[list[MultiModalGroup], _PendingRelease | None]:
"""Fetch multimodal embeddings with transparent cache-through. """Fetch multimodal embeddings with transparent cache-through.
...@@ -262,6 +264,7 @@ async def _fetch_embeddings( ...@@ -262,6 +264,7 @@ async def _fetch_embeddings(
miss_urls, miss_urls,
request_id, request_id,
receiver, receiver,
context=context,
) )
# ── 3. Update cache (no-op when cache is None) ────────────── # ── 3. Update cache (no-op when cache is None) ──────────────
...@@ -293,6 +296,7 @@ async def load_multimodal_embeddings( ...@@ -293,6 +296,7 @@ async def load_multimodal_embeddings(
model: str, model: str,
embeddings_dtype: torch.dtype, embeddings_dtype: torch.dtype,
cache: MultimodalEmbeddingCacheManager | None = None, cache: MultimodalEmbeddingCacheManager | None = None,
context=None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``. """Fetch embeddings and build engine-ready ``multi_modal_data``.
...@@ -307,6 +311,7 @@ async def load_multimodal_embeddings( ...@@ -307,6 +311,7 @@ async def load_multimodal_embeddings(
request_id, request_id,
receiver, receiver,
cache=cache, cache=cache,
context=context,
) )
multi_modal_data: Dict[str, Any] = defaultdict(list) multi_modal_data: Dict[str, Any] = defaultdict(list)
......
...@@ -299,7 +299,7 @@ class TestGenerateDisagg: ...@@ -299,7 +299,7 @@ class TestGenerateDisagg:
decode_resp = MagicMock() decode_resp = MagicMock()
decode_resp.data.return_value = decode_json decode_resp.data.return_value = decode_json
async def fake_round_robin(payload): async def fake_round_robin(payload, context=None):
async def _stream(): async def _stream():
yield decode_resp yield decode_resp
......
...@@ -209,19 +209,35 @@ class Client: ...@@ -209,19 +209,35 @@ class Client:
""" """
... ...
async def random(self, request: JsonLike) -> AsyncIterator[JsonLike]: async def random(
self,
request: JsonLike,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
""" """
Pick a random instance of the endpoint and issue the request Pick a random instance of the endpoint and issue the request
""" """
... ...
async def round_robin(self, request: JsonLike) -> AsyncIterator[JsonLike]: async def round_robin(
self,
request: JsonLike,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
""" """
Pick the next instance of the endpoint in a round-robin fashion Pick the next instance of the endpoint in a round-robin fashion
""" """
... ...
async def direct(self, request: JsonLike, instance: str) -> AsyncIterator[JsonLike]: async def direct(
self,
request: JsonLike,
instance_id: int,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
""" """
Pick a specific instance of the endpoint Pick a specific instance of the endpoint
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment