"lib/ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "5755a8dec365b6d13765db18498a4e8ba76fa377"
Unverified Commit eb0bf24e authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

feat: Add NVTX markers for vLLM EPD (#6627)

parent e15685d8
...@@ -20,6 +20,7 @@ from pydantic import BaseModel ...@@ -20,6 +20,7 @@ from pydantic import BaseModel
from safetensors import torch as safetensors_torch from safetensors import torch as safetensors_torch
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -141,6 +142,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender): ...@@ -141,6 +142,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
) )
return tensor_path return tensor_path
@_nvtx.annotate("mm:local:send_embeddings", color="magenta")
async def send_embeddings( async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]: ) -> tuple[TransferRequest, asyncio.Future]:
...@@ -185,6 +187,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -185,6 +187,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver):
self.received_tensors = {} self.received_tensors = {}
self.tensor_id_counter = 0 self.tensor_id_counter = 0
@_nvtx.annotate("mm:local:receive_embeddings", color="magenta")
async def receive_embeddings( async def receive_embeddings(
self, request: TransferRequest self, request: TransferRequest
) -> tuple[int, torch.Tensor]: ) -> tuple[int, torch.Tensor]:
...@@ -803,6 +806,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender): ...@@ -803,6 +806,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
def __init__(self): def __init__(self):
self.connector = PersistentConnector() self.connector = PersistentConnector()
@_nvtx.annotate("mm:nixl:send_embeddings", color="magenta")
async def send_embeddings( async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]: ) -> tuple[TransferRequest, asyncio.Future]:
...@@ -821,9 +825,10 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender): ...@@ -821,9 +825,10 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
transfer_buf = embeddings transfer_buf = embeddings
else: else:
transfer_buf = embeddings.clone().detach() transfer_buf = embeddings.clone().detach()
descriptor = nixl_connect.Descriptor(transfer_buf) with _nvtx.annotate("mm:nixl:create_descriptor", color="pink"):
readable_op = await self.connector.create_readable(descriptor) descriptor = nixl_connect.Descriptor(transfer_buf)
with _nvtx.annotate("mm:nixl:create_readable", color="pink"):
readable_op = await self.connector.create_readable(descriptor)
request = TransferRequest( request = TransferRequest(
embeddings_shape=list(embeddings.shape), embeddings_shape=list(embeddings.shape),
embedding_dtype_str=torch_dtype_to_string(embeddings.dtype), embedding_dtype_str=torch_dtype_to_string(embeddings.dtype),
...@@ -877,6 +882,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -877,6 +882,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
descriptor.register_with_connector(connection) descriptor.register_with_connector(connection)
self.warmedup_descriptors.put(descriptor) self.warmedup_descriptors.put(descriptor)
@_nvtx.annotate("mm:nixl:receive_embeddings", color="magenta")
async def receive_embeddings( async def receive_embeddings(
self, request: TransferRequest self, request: TransferRequest
) -> tuple[int, torch.Tensor]: ) -> tuple[int, torch.Tensor]:
...@@ -918,10 +924,12 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -918,10 +924,12 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
) )
dynamic_descriptor = False dynamic_descriptor = False
# Create read operation to read from EncodeHandler with _nvtx.annotate("mm:nixl:begin_read", color="pink"):
read_op = await self.connector.begin_read(readable_metadata, descriptor) # Create read operation to read from EncodeHandler
# Wait for the read operation to complete read_op = await self.connector.begin_read(readable_metadata, descriptor)
await read_op.wait_for_completion() with _nvtx.annotate("mm:nixl:wait_completion", color="pink"):
# Wait for the read operation to complete
await read_op.wait_for_completion()
logging.debug( logging.debug(
f"Successfully read embeddings via NIXL: {encodings_tensor.shape}" f"Successfully read embeddings via NIXL: {encodings_tensor.shape}"
) )
......
...@@ -25,6 +25,7 @@ import httpx ...@@ -25,6 +25,7 @@ import httpx
from PIL import Image from PIL import Image
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from .http_client import get_http_client from .http_client import get_http_client
...@@ -46,6 +47,7 @@ class ImageLoader: ...@@ -46,6 +47,7 @@ class ImageLoader:
self._image_cache: dict[str, Image.Image] = {} self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
@_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image: async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url) parsed_url = urlparse(image_url)
...@@ -58,44 +60,47 @@ class ImageLoader: ...@@ -58,44 +60,47 @@ class ImageLoader:
try: try:
if parsed_url.scheme == "data": if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data> with _nvtx.annotate("mm:img:base64_decode", color="lime"):
if not parsed_url.path.startswith("image/"): # Parse data URL format: data:[<media type>][;base64],<data>
raise ValueError("Data URL must be an image type") if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1) # Split the path into media type and data
if ";base64" not in media_type: media_type, data = parsed_url.path.split(",", 1)
raise ValueError("Data URL must be base64 encoded") if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data) try:
image_data = BytesIO(image_bytes) image_bytes = base64.b64decode(data)
except binascii.Error as e: image_data = BytesIO(image_bytes)
raise ValueError(f"Invalid base64 encoding: {e}") except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"): elif parsed_url.scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout) with _nvtx.annotate("mm:img:http_fetch", color="lime"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url) response = await http_client.get(image_url)
response.raise_for_status() response.raise_for_status()
if not response.content: if not response.content:
raise ValueError("Empty response content from image URL") raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content) image_data = BytesIO(response.content)
else: else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}") raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop with _nvtx.annotate("mm:img:pil_open_convert", color="lime"):
# Restrict to supported formats to prevent PSD parsing (GHSA-cfh3-3jmp-rvhc) # PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread( # Restrict to supported formats to prevent PSD parsing (GHSA-cfh3-3jmp-rvhc)
Image.open, image_data, formats=["JPEG", "PNG", "WEBP"] image = await asyncio.to_thread(
) Image.open, image_data, formats=["JPEG", "PNG", "WEBP"]
)
# Validate image format and convert to RGB # Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"): if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}") raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB") image_converted = image.convert("RGB")
# Cache HTTP(S) URLs # Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"): if parsed_url.scheme in ("http", "https"):
......
...@@ -9,6 +9,7 @@ Dynamo backends and components. ...@@ -9,6 +9,7 @@ Dynamo backends and components.
Submodules: Submodules:
- endpoint_types: Endpoint type parsing utilities - endpoint_types: Endpoint type parsing utilities
- nvtx_utils: NVTX profiling wrappers (enable with DYN_NVTX=1; no-ops by default)
- otel_tracing: OpenTelemetry tracing header utilities - otel_tracing: OpenTelemetry tracing header utilities
- paths: Workspace directory detection and path utilities - paths: Workspace directory detection and path utilities
- prometheus: Prometheus metrics collection and logging utilities - prometheus: Prometheus metrics collection and logging utilities
...@@ -18,6 +19,7 @@ from dynamo.common.utils import ( ...@@ -18,6 +19,7 @@ from dynamo.common.utils import (
endpoint_types, endpoint_types,
engine_response, engine_response,
namespace, namespace,
nvtx_utils,
otel_tracing, otel_tracing,
paths, paths,
prometheus, prometheus,
...@@ -28,6 +30,7 @@ __all__ = [ ...@@ -28,6 +30,7 @@ __all__ = [
"endpoint_types", "endpoint_types",
"engine_response", "engine_response",
"namespace", "namespace",
"nvtx_utils",
"otel_tracing", "otel_tracing",
"paths", "paths",
"prometheus", "prometheus",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Lightweight NVTX wrappers for Dynamo profiling.
Set DYN_NVTX=1 to enable markers; default is disabled (zero overhead).
Usage — same syntax as the bare nvtx module:
from dynamo.common.utils import nvtx_utils as _nvtx
# Manual range (needed when the range spans async yields or has conditional end)
rng = _nvtx.start_range("my:range", color="blue")
...
_nvtx.end_range(rng)
# Decorator — annotates an entire function or async generator
@_nvtx.annotate("my:func", color="green")
def my_func(): ...
@_nvtx.range_decorator("my:async_gen", color="green")
async def my_async_gen():
yield ...
# Context manager — annotates a block (works with await and yield inside)
with _nvtx.annotate("my:block", color="cyan"):
result = await some_coroutine()
When enabled, uses a named nvtx.Domain and pre-allocated EventAttributes
objects (cached lazily by (message, color)) so that repeated calls to
start_range incur only a single dict lookup — no object allocation
or domain cache lookups on the hot path.
"""
import functools
import inspect
import os
def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Integer strings keep their historical meaning (``"0"`` is off, any other
    integer is on).  Unset or blank values fall back to *default*, and the
    common spellings ``true``/``yes``/``on`` enable the flag.  Any other text
    is treated as off instead of raising, so a mistyped value can never break
    ``import`` of this module.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    text = raw.strip().lower()
    if not text:
        return default
    try:
        return bool(int(text))
    except ValueError:
        return text in ("true", "yes", "on")


# Read once at import time: the branch below selects real or no-op
# implementations, so there is no per-call check of the environment.
ENABLED: bool = _env_flag("DYN_NVTX")

if ENABLED:
    import nvtx as _nvtx_lib

    # Named domain + pre-allocated EventAttributes: no per-call object
    # allocation or domain cache lookups on the hot path.
    _domain = _nvtx_lib.get_domain("dynamo")
    _attr_cache: dict = {}

    def _get_attr(message: str, color: str):
        """Return the cached EventAttributes for ``(message, color)``."""
        try:
            return _attr_cache[message, color]
        except KeyError:
            attr = _domain.get_event_attributes(message=message, color=color)
            _attr_cache[message, color] = attr
            return attr

    def start_range(message: str, color: str = "white"):
        """Open a range in the "dynamo" domain; pass the result to end_range()."""
        return _domain.start_range(_get_attr(message, color))

    def end_range(rng) -> None:
        """Close a range previously returned by start_range()."""
        _domain.end_range(rng)

    # functools.partial so decorator and context-manager usage both land
    # in the "dynamo" domain, keeping all markers in one nsys row.
    annotate = functools.partial(_nvtx_lib.annotate, domain="dynamo")

    def range_decorator(message: str, color: str = "white"):
        """Decorator that wraps a callable in a single start/end NVTX range.

        The wrapper is chosen from the kind of function being decorated:

        * async generator function: the range stays open for the whole
          iteration, not just the synchronous setup before the first yield;
        * coroutine function (``async def`` without ``yield``): the range
          stays open until the awaited call finishes;
        * anything else: the range covers the synchronous call.

        The range is always closed in a ``finally`` block, so it also ends
        when the wrapped callable raises or the generator is closed early.
        """

        def decorator(func):
            if inspect.isasyncgenfunction(func):

                @functools.wraps(func)
                async def asyncgen_wrapper(*args, **kwargs):
                    rng = start_range(message, color)
                    try:
                        async for item in func(*args, **kwargs):
                            yield item
                    finally:
                        end_range(rng)

                return asyncgen_wrapper

            if inspect.iscoroutinefunction(func):
                # Calling a coroutine function only creates the coroutine
                # object; the range has to stay open across the await to
                # cover the work it actually performs.
                @functools.wraps(func)
                async def coroutine_wrapper(*args, **kwargs):
                    rng = start_range(message, color)
                    try:
                        return await func(*args, **kwargs)
                    finally:
                        end_range(rng)

                return coroutine_wrapper

            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                rng = start_range(message, color)
                try:
                    return func(*args, **kwargs)
                finally:
                    end_range(rng)

            return sync_wrapper

        return decorator

else:
    # Pure Python no-ops: no C extension calls, no string allocations.
    # The ENV var is read once at import time — no per-call branch overhead.

    def start_range(message: str, color: str = "white"):  # type: ignore[misc]
        """No-op: returns None as the range handle."""
        return None

    def end_range(rng) -> None:  # type: ignore[misc]
        """No-op: accepts whatever start_range() returned."""
        pass

    class _NoOpAnnotate:
        """No-op that works as both a decorator and a context manager."""

        __slots__ = ()

        def __call__(self, func):
            return func

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

    # One shared, stateless instance is handed out for every annotate() call.
    _noop_annotate = _NoOpAnnotate()

    def annotate(message: str = "", color: str = "white", **_kwargs):  # type: ignore[misc]
        """No-op stand-in for the enabled ``annotate``.

        Extra keyword arguments are accepted and ignored so that a call
        site passing options to the real annotate keeps working when
        profiling is switched off.
        """
        return _noop_annotate

    def range_decorator(message: str = "", color: str = "white"):  # type: ignore[misc]
        """No-op decorator: returns the wrapped function unchanged."""

        def decorator(func):
            return func

        return decorator
...@@ -18,6 +18,7 @@ from dynamo.common.multimodal import ( ...@@ -18,6 +18,7 @@ from dynamo.common.multimodal import (
NixlReadEmbeddingSender, NixlReadEmbeddingSender,
NixlWriteEmbeddingSender, NixlWriteEmbeddingSender,
) )
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..constants import EmbeddingTransferMode from ..constants import EmbeddingTransferMode
...@@ -119,6 +120,7 @@ class EncodeWorkerHandler: ...@@ -119,6 +120,7 @@ class EncodeWorkerHandler:
self._connector = connect.Connector() self._connector = connect.Connector()
logger.info("Encode worker startup completed.") logger.info("Encode worker startup completed.")
@_nvtx.range_decorator("mm:encode_worker_generate", color="blue")
async def generate( async def generate(
self, request: vLLMMultimodalRequest, context self, request: vLLMMultimodalRequest, context
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
...@@ -144,95 +146,106 @@ class EncodeWorkerHandler: ...@@ -144,95 +146,106 @@ class EncodeWorkerHandler:
try: try:
time_start = time.perf_counter() time_start = time.perf_counter()
# Before batch process images, check cache first
need_encode_indexes = [] with _nvtx.annotate("mm:enc:cache_check", color="cyan"):
embedding_lists = [None] * len(request.multimodal_inputs) # Before batch process images, check cache first
for idx in range(len(request.multimodal_inputs)): need_encode_indexes = []
if not request.multimodal_inputs[idx].multimodal_input.image_url: embedding_lists = [None] * len(request.multimodal_inputs)
raise ValueError("image_url is required for the encode worker.") for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url:
image_url = request.multimodal_inputs[idx].multimodal_input.image_url raise ValueError("image_url is required for the encode worker.")
# see if we have local cache
embedding_key = EmbeddingCache.generate_hash_key(image_url) image_url = request.multimodal_inputs[
if self.embedding_cache is not None and self.embedding_cache.has_key( idx
embedding_key ].multimodal_input.image_url
): # see if we have local cache
(image_grid_thw, embeddings) = self.embedding_cache.get( embedding_key = EmbeddingCache.generate_hash_key(image_url)
embedding_key if (
) self.embedding_cache is not None
embedding_lists[idx] = EmbeddingItem( and self.embedding_cache.has_key(embedding_key)
embedding_key, image_grid_thw, embeddings ):
(image_grid_thw, embeddings) = self.embedding_cache.get(
embedding_key
)
embedding_lists[idx] = EmbeddingItem(
embedding_key, image_grid_thw, embeddings
)
# compute
else:
# keep track of key to avoid recompute of it
need_encode_indexes.append((idx, embedding_key))
with _nvtx.annotate("mm:enc:image_load", color="green"):
# Load and generate image tensors
image_tasks = []
image_to_load = []
for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url
image_tasks.append(
asyncio.create_task(self.image_loader.load_image(url))
) )
# compute image_to_load.append(url)
else: results = await asyncio.gather(*image_tasks, return_exceptions=True)
# keep track of key to avoid recompute of it loaded_images = []
need_encode_indexes.append((idx, embedding_key)) collective_exceptions = ""
for i, result in enumerate(results):
# Load and generate image tensors if isinstance(result, Exception):
image_tasks = [] url = image_to_load[i]
image_to_load = [] logger.error(
for idx, _ in need_encode_indexes: f"Failed to load image from {url[:80]}...: {result}"
url = request.multimodal_inputs[idx].multimodal_input.image_url )
image_tasks.append( collective_exceptions += (
asyncio.create_task(self.image_loader.load_image(url)) f"Failed to load image from {url[:80]}...: {result}\n"
) )
image_to_load.append(url) continue
results = await asyncio.gather(*image_tasks, return_exceptions=True) loaded_images.append(result)
loaded_images = [] if collective_exceptions:
collective_exceptions = "" raise ValueError(
for i, result in enumerate(results): f"Errors occurred during image loading:\n{collective_exceptions}"
if isinstance(result, Exception):
url = image_to_load[i]
logger.error(f"Failed to load image from {url[:80]}...: {result}")
collective_exceptions += (
f"Failed to load image from {url[:80]}...: {result}\n"
) )
continue
loaded_images.append(result)
if collective_exceptions:
raise ValueError(
f"Errors occurred during image loading:\n{collective_exceptions}"
)
if loaded_images: if loaded_images:
image_embeds = await asyncio.to_thread( with _nvtx.annotate("mm:enc:image_preprocess", color="yellow"):
self.image_processor, images=loaded_images, return_tensors="pt" image_embeds = await asyncio.to_thread(
) self.image_processor, images=loaded_images, return_tensors="pt"
)
# Encode the image embeddings using model-specific encoder with _nvtx.annotate("mm:enc:vision_encode", color="red"):
embeddings = await asyncio.to_thread( # Encode the image embeddings using model-specific encoder
encode_image_embeddings, embeddings = await asyncio.to_thread(
model_name=self.model, encode_image_embeddings,
image_embeds=image_embeds, model_name=self.model,
vision_encoder=self.vision_encoder, image_embeds=image_embeds,
projector=self.projector, vision_encoder=self.vision_encoder,
) projector=self.projector,
)
# [gluo FIXME] This is specific to qwen vision processing.. with _nvtx.annotate("mm:enc:split_embeddings", color="orange"):
# Split concatenated embeddings for each image item. # [gluo FIXME] This is specific to qwen vision processing..
if is_qwen_vl_model(self.model): # Split concatenated embeddings for each image item.
merge_size = self.vision_encoder.spatial_merge_size if is_qwen_vl_model(self.model):
sizes = ( merge_size = self.vision_encoder.spatial_merge_size
image_embeds["image_grid_thw"].prod(-1) sizes = (
// merge_size image_embeds["image_grid_thw"].prod(-1)
// merge_size // merge_size
).tolist() // merge_size
splitted_embeddings = embeddings.squeeze(0).split(sizes) ).tolist()
logger.debug( splitted_embeddings = embeddings.squeeze(0).split(sizes)
f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}" logger.debug(
f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}"
)
else:
# Validated on llava (NOTE need to double check on other models) that the
# embeddings already has batch dimension for images, so we can directly
# split by batch dimension
logger.debug(f"image embedding shape: {embeddings.shape}")
splitted_embeddings = embeddings
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
) )
else:
# Validated on llava (NOTE need to double check on other models) that the
# embeddings already has batch dimension for images, so we can directly
# split by batch dimension
logger.debug(f"image embedding shape: {embeddings.shape}")
splitted_embeddings = embeddings
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
# fill in the embedding_lists with new computed embeddings and cache them # fill in the embedding_lists with new computed embeddings and cache them
for split_idx, (list_idx, key) in enumerate(need_encode_indexes): for split_idx, (list_idx, key) in enumerate(need_encode_indexes):
...@@ -253,38 +266,41 @@ class EncodeWorkerHandler: ...@@ -253,38 +266,41 @@ class EncodeWorkerHandler:
before_transfer_time = time.perf_counter() before_transfer_time = time.perf_counter()
# Prepare transfer with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
send_tasks = [ # Prepare transfer
asyncio.create_task( send_tasks = [
self.embedding_sender.send_embeddings( asyncio.create_task(
embedding_item.embeddings, stage_embeddings=True self.embedding_sender.send_embeddings(
embedding_item.embeddings, stage_embeddings=True
)
) )
) for embedding_item in embedding_lists
for embedding_item in embedding_lists ]
] transfer_requests = await asyncio.gather(*send_tasks)
transfer_requests = await asyncio.gather(*send_tasks)
after_transfer_time = time.perf_counter() after_transfer_time = time.perf_counter()
for idx, item in enumerate(zip(embedding_lists, transfer_requests)): for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item embedding_item, transfer_request = item
logger.debug( logger.debug(
f"{embedding_item.embeddings.shape} prepared for transfer." f"{embedding_item.embeddings.shape} prepared for transfer."
) )
# Update request for transfer metadata # Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[ request.multimodal_inputs[
idx idx
].image_grid_thw = embedding_item.image_grid_thw ].image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple( request.multimodal_inputs[idx].embeddings_shape = tuple(
embedding_item.embeddings.shape embedding_item.embeddings.shape
) )
request.multimodal_inputs[idx].serialized_request = transfer_request[0] request.multimodal_inputs[
idx
].serialized_request = transfer_request[0]
# Keep a reference of the embedding and only drop reference when the transfer is done # Keep a reference of the embedding and only drop reference when the transfer is done
self.send_complete_queue.put_nowait( self.send_complete_queue.put_nowait(
(transfer_request[1], embedding_item.embeddings) (transfer_request[1], embedding_item.embeddings)
) )
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
......
...@@ -20,6 +20,7 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -20,6 +20,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlReadEmbeddingReceiver, NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver, NixlWriteEmbeddingReceiver,
) )
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
from ..args import Config from ..args import Config
...@@ -257,6 +258,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -257,6 +258,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self, self,
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
rng_ttft=None,
): ):
"""Run prefill and decode on this worker (aggregated mode).""" """Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model) lora_request = self._resolve_lora_request(request.model)
...@@ -271,14 +273,26 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -271,14 +273,26 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
) )
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for response in gen: first_token = True
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}") try:
logger.debug( async for response in gen:
f"length of expanded prompt ids: {len(response.prompt_token_ids)}" if first_token:
) if rng_ttft is not None:
yield self._format_engine_output(response, num_output_tokens_so_far) _nvtx.end_range(rng_ttft)
if response.outputs: first_token = False
num_output_tokens_so_far = len(response.outputs[0].token_ids) logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
logger.debug(
f"length of expanded prompt ids: {len(response.prompt_token_ids)}"
)
yield self._format_engine_output(response, num_output_tokens_so_far)
if response.outputs:
num_output_tokens_so_far = len(response.outputs[0].token_ids)
finally:
if first_token:
if rng_ttft is not None:
_nvtx.end_range(rng_ttft)
# ── Disaggregated generation (prefill here, decode remote) ─────── # ── Disaggregated generation (prefill here, decode remote) ───────
...@@ -286,6 +300,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -286,6 +300,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self, self,
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
rng_ttft=None,
): ):
"""Prefill locally, then forward to a remote decode worker.""" """Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request # Prepare prefill-only request
...@@ -298,19 +313,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -298,19 +313,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug("Prefill request: %s", prefill_only_request) logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model) lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate( with _nvtx.annotate("mm:pd:disagg_prefill", color="darkred"):
prompt=TokensPrompt( gen = self.engine_client.generate(
prompt_token_ids=prefill_only_request.engine_prompt["prompt_token_ids"], prompt=TokensPrompt(
multi_modal_data=multi_modal_data, prompt_token_ids=prefill_only_request.engine_prompt[
), "prompt_token_ids"
sampling_params=prefill_only_request.sampling_params, ],
request_id=prefill_only_request.request_id, multi_modal_data=multi_modal_data,
lora_request=lora_request, ),
) sampling_params=prefill_only_request.sampling_params,
request_id=prefill_only_request.request_id,
lora_request=lora_request,
)
# Drain prefill generator (max_tokens=1, expect a single response) # Drain prefill generator (max_tokens=1, expect a single response)
async for prefill_response in gen: async for prefill_response in gen:
pass pass
if rng_ttft is not None:
_nvtx.end_range(rng_ttft)
# Qwen VL (mRoPE): keep the ORIGINAL unexpanded prompt. # Qwen VL (mRoPE): keep the ORIGINAL unexpanded prompt.
# The decode worker passes multi_modal_data which causes vLLM to # The decode worker passes multi_modal_data which causes vLLM to
...@@ -347,32 +367,49 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -347,32 +367,49 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f"— ensure the same adapter is loaded on the decode worker." f"— ensure the same adapter is loaded on the decode worker."
) )
num_output_tokens_so_far = 0 with _nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"):
async for ( num_output_tokens_so_far = 0
decode_response async for (
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr] decode_response
request.model_dump_json() ) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
): request.model_dump_json()
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined] ):
yield self._format_engine_output(output, num_output_tokens_so_far) output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined]
if output.outputs: yield self._format_engine_output(output, num_output_tokens_so_far)
num_output_tokens_so_far = len(output.outputs[0].token_ids) if output.outputs:
num_output_tokens_so_far = len(output.outputs[0].token_ids)
# ── Public entry point ─────────────────────────────────────────── # ── Public entry point ───────────────────────────────────────────
async def generate(self, raw_request: dict, context): async def generate(self, raw_request: dict, context):
"""Parse the request, load multimodal data, and run inference.""" """Parse the request, load multimodal data, and run inference."""
rng_pd = _nvtx.start_range("mm:pd_worker_generate", color="green")
rng_ttft = _nvtx.start_range("mm:pd:ttft", color="orange")
rng_parse = _nvtx.start_range("mm:pd:parse_request", color="cyan")
request, image_urls = self._parse_frontend_request(raw_request) request, image_urls = self._parse_frontend_request(raw_request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
_nvtx.end_range(rng_parse)
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
multi_modal_data = await self._load_multimodal_data( multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id image_urls, request.request_id
) )
_nvtx.end_range(rng_load)
self._finalize_request_metadata(request, multi_modal_data) self._finalize_request_metadata(request, multi_modal_data)
if self.enable_disagg and self.decode_worker_client: if self.enable_disagg and self.decode_worker_client:
async for chunk in self._generate_disagg(request, multi_modal_data): rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
async for chunk in self._generate_disagg(
request, multi_modal_data, rng_ttft
):
yield chunk yield chunk
_nvtx.end_range(rng_disagg)
else: else:
async for chunk in self._generate_agg(request, multi_modal_data): rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
async for chunk in self._generate_agg(request, multi_modal_data, rng_ttft):
yield chunk yield chunk
_nvtx.end_range(rng_agg)
_nvtx.end_range(rng_pd)
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..args import Config from ..args import Config
...@@ -53,6 +54,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -53,6 +54,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
logger.info("Multimodal Decode Worker async initialization completed.") logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}") logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest): if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str): if isinstance(request, str):
...@@ -95,15 +97,27 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -95,15 +97,27 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
lora_request=lora_request, lora_request=lora_request,
) )
async for response in gen: rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}") first_token = True
yield MyRequestOutput( try:
request_id=response.request_id, async for response in gen:
prompt=response.prompt, if first_token:
prompt_token_ids=response.prompt_token_ids, _nvtx.end_range(rng_first)
prompt_logprobs=response.prompt_logprobs, first_token = False
outputs=response.outputs, logger.debug(
finished=response.finished, f"Response kv_transfer_params: {response.kv_transfer_params}"
metrics=response.metrics, )
kv_transfer_params=response.kv_transfer_params, yield MyRequestOutput(
).model_dump_json() request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
finally:
if first_token:
_nvtx.end_range(rng_first)
_nvtx.end_range(rng_decode)
...@@ -127,7 +127,7 @@ def encode_image_embeddings( ...@@ -127,7 +127,7 @@ def encode_image_embeddings(
embeddings = embeddings[0] embeddings = embeddings[0]
embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings
return embeddings return embeddings
def get_encoder_components( def get_encoder_components(
......
...@@ -30,6 +30,7 @@ msgpack==1.1.2 ...@@ -30,6 +30,7 @@ msgpack==1.1.2
msgspec==0.19.0 msgspec==0.19.0
mypy==1.18.2 mypy==1.18.2
nvidia-ml-py<=13.580.65 # NVIDIA/CUDA related, may vary by driver version nvidia-ml-py<=13.580.65 # NVIDIA/CUDA related, may vary by driver version
nvtx==0.2.14
opentelemetry-api<=1.38.0 # May need to stay in sync with other components opentelemetry-api<=1.38.0 # May need to stay in sync with other components
opentelemetry-exporter-otlp<=1.38.0 # May need to stay in sync with other components opentelemetry-exporter-otlp<=1.38.0 # May need to stay in sync with other components
opentelemetry-sdk<=1.38.0 # May need to stay in sync with other components opentelemetry-sdk<=1.38.0 # May need to stay in sync with other components
......
...@@ -611,6 +611,38 @@ curl -X POST http://<decode-worker>/load_lora \ ...@@ -611,6 +611,38 @@ curl -X POST http://<decode-worker>/load_lora \
If a LoRA is loaded on the prefill worker but not on the decode worker, the decode worker will fall back to the base model for that request.
## Profiling
Dynamo's multimodal workers include NVTX markers for profiling with NVIDIA Nsight Systems (`nsys`). The markers are disabled by default (zero overhead) and are enabled by setting `DYN_NVTX=1`.
```bash
cd $DYNAMO_HOME/examples/backends/vllm
DYN_NVTX=1 nsys profile --trace=cuda,nvtx -o profile.nsys-rep \
bash launch/agg_multimodal.sh ...
```
| ENV Variable | Default | Description |
|---|---|---|
| `DYN_NVTX` | `0` | Set to `1` to enable NVTX range/mark annotations in encode, prefill, and decode workers for `nsys` profiling |
Key NVTX ranges emitted:
| Range | Worker | Description |
|-------|--------|-------------|
| `mm:encode_worker_generate` | Encode | Full encode request lifetime |
| `mm:enc:cache_check` | Encode | Embedding cache lookup |
| `mm:enc:image_load` | Encode | Image download/load |
| `mm:enc:image_preprocess` | Encode | Image processor (CPU) |
| `mm:enc:vision_encode` | Encode | ViT + projector GPU forward |
| `mm:enc:embedding_transfer` | Encode | RDMA embedding staging |
| `mm:pd_worker_generate` | PD | Full PD request lifetime |
| `mm:pd:ttft` | PD | Worker-side TTFT: from request arrival at the PD worker to first output token (excludes client→frontend→worker network transit) |
| `mm:pd:load_multimodal` | PD | Fetch embeddings from encode worker |
| `mm:pd:disagg_prefill` | PD (disagg) | Prefill-only engine call |
| `mm:pd:disagg_remote_decode` | PD (disagg) | Remote decode round-trip |
| `mm:decode_worker_generate` | Decode | Full decode request lifetime |
| `mm:decode:first_token` | Decode | Time to first output token |
## Known Limitations
- **Disaggregated flows require Python Processor** - All multimodal disaggregation requires the Python Processor component (`ModelInput.Text`).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment