Unverified Commit c82fe888 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: add embedding cache to pd worker (#6061)

parent ebc61637
......@@ -19,13 +19,18 @@ Usage:
import logging
from collections import OrderedDict
from typing import Optional
from typing import NamedTuple, Optional
import torch
logger = logging.getLogger(__name__)
class CachedEmbedding(NamedTuple):
tensor: torch.Tensor
image_grid_thw: list | None = None
class MultimodalEmbeddingCacheManager:
"""
LRU cache for encoder embeddings.
......@@ -47,7 +52,7 @@ class MultimodalEmbeddingCacheManager:
Args:
capacity_bytes: Maximum cache capacity in bytes.
"""
self._cache: OrderedDict[str, torch.Tensor] = OrderedDict()
self._cache: OrderedDict[str, CachedEmbedding] = OrderedDict()
self._capacity_bytes = capacity_bytes
self._current_bytes = 0
......@@ -77,9 +82,9 @@ class MultimodalEmbeddingCacheManager:
), "Tensor must be contiguous for accurate size calculation"
return tensor.element_size() * tensor.numel()
def get(self, key: str) -> Optional[torch.Tensor]:
def get(self, key: str) -> Optional[CachedEmbedding]:
"""
Get a tensor from the cache.
Get a cached embedding from the cache.
If found, the entry is moved to the end (most recently used).
......@@ -87,7 +92,7 @@ class MultimodalEmbeddingCacheManager:
key: Cache key (typically content hash).
Returns:
The cached tensor, or None if not found.
The cached embedding, or None if not found.
"""
if key not in self._cache:
self._misses += 1
......@@ -98,22 +103,22 @@ class MultimodalEmbeddingCacheManager:
self._hits += 1
return self._cache[key]
def set(self, key: str, tensor: torch.Tensor) -> bool:
def set(self, key: str, entry: CachedEmbedding) -> bool:
"""
Store a tensor in the cache.
Store a cached embedding in the cache.
If the key already exists, the old value is replaced.
If adding the tensor would exceed capacity, LRU entries are evicted.
If adding the entry would exceed capacity, LRU entries are evicted.
If the tensor itself is larger than capacity, it is not stored.
Args:
key: Cache key (typically content hash).
tensor: Tensor to cache.
entry: CachedEmbedding to cache.
Returns:
True if the tensor was stored, False if it was too large.
True if the entry was stored, False if it was too large.
"""
size = self._tensor_size(tensor)
size = self._tensor_size(entry.tensor)
# Don't cache if single tensor exceeds capacity
if size > self._capacity_bytes:
......@@ -125,20 +130,20 @@ class MultimodalEmbeddingCacheManager:
# If key exists, remove old entry first
if key in self._cache:
old_tensor = self._cache.pop(key)
self._current_bytes -= self._tensor_size(old_tensor)
old_entry = self._cache.pop(key)
self._current_bytes -= self._tensor_size(old_entry.tensor)
# Evict LRU entries until we have space
while self._current_bytes + size > self._capacity_bytes and self._cache:
evicted_key, evicted_tensor = self._cache.popitem(last=False)
evicted_size = self._tensor_size(evicted_tensor)
evicted_key, evicted_entry = self._cache.popitem(last=False)
evicted_size = self._tensor_size(evicted_entry.tensor)
self._current_bytes -= evicted_size
logger.debug(
f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB"
)
# Store new entry
self._cache[key] = tensor
self._cache[key] = entry
self._current_bytes += size
logger.debug(
......
......@@ -7,6 +7,7 @@ import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
......@@ -19,12 +20,25 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) # 1MB
tensor = torch.randn(100, 100) # ~40KB for float32
result = cache.set("key1", tensor)
result = cache.set("key1", CachedEmbedding(tensor))
assert result is True
retrieved = cache.get("key1")
assert retrieved is not None
assert torch.equal(retrieved, tensor)
assert torch.equal(retrieved.tensor, tensor)
assert retrieved.image_grid_thw is None
def test_set_and_get_with_grid(self):
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(100, 100)
grid = [[1, 2, 3]]
cache.set("key1", CachedEmbedding(tensor, grid))
retrieved = cache.get("key1")
assert retrieved is not None
assert torch.equal(retrieved.tensor, tensor)
assert retrieved.image_grid_thw == grid
def test_get_nonexistent_key(self):
"""Test get returns None for nonexistent key."""
......@@ -39,11 +53,11 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
tensor1 = torch.randn(10, 10)
tensor2 = torch.randn(10, 10)
cache.set("key1", tensor1)
cache.set("key1", tensor2)
cache.set("key1", CachedEmbedding(tensor1))
cache.set("key1", CachedEmbedding(tensor2))
retrieved = cache.get("key1")
assert torch.equal(retrieved, tensor2)
assert torch.equal(retrieved.tensor, tensor2)
assert cache.stats["entries"] == 1
......@@ -61,11 +75,11 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
t2 = torch.randn(10, 10)
t3 = torch.randn(10, 10)
cache.set("key1", t1)
cache.set("key2", t2)
cache.set("key1", CachedEmbedding(t1))
cache.set("key2", CachedEmbedding(t2))
# Adding third should evict first (LRU)
cache.set("key3", t3)
cache.set("key3", CachedEmbedding(t3))
assert cache.get("key1") is None # Evicted
assert cache.get("key2") is not None
......@@ -81,14 +95,14 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
t2 = torch.randn(10, 10)
t3 = torch.randn(10, 10)
cache.set("key1", t1)
cache.set("key2", t2)
cache.set("key1", CachedEmbedding(t1))
cache.set("key2", CachedEmbedding(t2))
# Access key1, making key2 the LRU
cache.get("key1")
# Adding third should evict key2 (now LRU)
cache.set("key3", t3)
cache.set("key3", CachedEmbedding(t3))
assert cache.get("key1") is not None # Not evicted (recently accessed)
assert cache.get("key2") is None # Evicted (LRU)
......@@ -99,7 +113,7 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
cache = MultimodalEmbeddingCacheManager(capacity_bytes=100) # Very small
tensor = torch.randn(100, 100) # ~40KB, way larger than capacity
result = cache.set("key1", tensor)
result = cache.set("key1", CachedEmbedding(tensor))
assert result is False
assert cache.get("key1") is None
......@@ -119,10 +133,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
expected_size_1 = t1.element_size() * t1.numel()
expected_size_2 = t2.element_size() * t2.numel()
cache.set("key1", t1)
cache.set("key1", CachedEmbedding(t1))
assert cache.stats["current_bytes"] == expected_size_1
cache.set("key2", t2)
cache.set("key2", CachedEmbedding(t2))
assert cache.stats["current_bytes"] == expected_size_1 + expected_size_2
def test_size_updated_on_overwrite(self):
......@@ -132,10 +146,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
small_tensor = torch.randn(10, 10) # 400 bytes
large_tensor = torch.randn(20, 20) # 1600 bytes
cache.set("key1", small_tensor)
cache.set("key1", CachedEmbedding(small_tensor))
initial_size = cache.stats["current_bytes"]
cache.set("key1", large_tensor)
cache.set("key1", CachedEmbedding(large_tensor))
expected_size = large_tensor.element_size() * large_tensor.numel()
assert cache.stats["current_bytes"] == expected_size
......@@ -150,7 +164,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10)
cache.set("key1", tensor)
cache.set("key1", CachedEmbedding(tensor))
# Misses
cache.get("nonexistent1")
......@@ -170,7 +184,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
"""Test stats dictionary contains expected keys."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10)
cache.set("key1", tensor)
cache.set("key1", CachedEmbedding(tensor))
stats = cache.stats
......@@ -190,10 +204,9 @@ class TestMultimodalEmbeddingCacheManagerStats:
capacity = 1000
cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
# Create tensor of known size
# float32 = 4 bytes, so 25 elements = 100 bytes
tensor = torch.zeros(25, dtype=torch.float32)
cache.set("key1", tensor)
cache.set("key1", CachedEmbedding(tensor))
stats = cache.stats
expected_utilization = 100 / capacity
......@@ -209,7 +222,7 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
tensor = torch.randn(10, 10)
assert tensor.is_contiguous()
result = cache.set("key1", tensor)
result = cache.set("key1", CachedEmbedding(tensor))
assert result is True
def test_set_non_contiguous_tensor_raises(self):
......@@ -221,4 +234,29 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
assert not tensor.is_contiguous()
with pytest.raises(AssertionError, match="Tensor must be contiguous"):
cache.set("key1", tensor)
cache.set("key1", CachedEmbedding(tensor))
class TestCachedEmbeddingNamedTuple:
"""Tests for CachedEmbedding NamedTuple."""
def test_fields(self):
tensor = torch.randn(4, 4)
grid = [[1, 2, 3]]
entry = CachedEmbedding(tensor=tensor, image_grid_thw=grid)
assert torch.equal(entry.tensor, tensor)
assert entry.image_grid_thw == grid
def test_none_grid(self):
tensor = torch.randn(4, 4)
entry = CachedEmbedding(tensor=tensor, image_grid_thw=None)
assert entry.image_grid_thw is None
def test_unpacking(self):
tensor = torch.randn(4, 4)
grid = [[1, 2, 3]]
entry = CachedEmbedding(tensor=tensor, image_grid_thw=grid)
t, g = entry
assert torch.equal(t, tensor)
assert g == grid
......@@ -15,6 +15,7 @@ import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles
......@@ -148,7 +149,7 @@ async def _fetch_embeddings_with_cache(
cached = cache.get(url_hash)
if cached is not None:
logger.info(f"fetch_embeddings_with_cache: cache hit for URL: {url}")
embeddings_with_index.append((i, cached))
embeddings_with_index.append((i, cached.tensor))
else:
logger.info(f"fetch_embeddings_with_cache: cache miss for URL: {url}")
uncached_urls.append(url)
......@@ -189,7 +190,7 @@ async def _fetch_embeddings_with_cache(
# Cache new tensors (reuse hashes computed during cache lookup)
for url, url_hash, tensor in zip(uncached_urls, uncached_hashes, new_tensors):
cache.set(url_hash, tensor)
cache.set(url_hash, CachedEmbedding(tensor=tensor))
logger.info(
f"fetch_embeddings_with_cache: cached embedding for URL: {url}, shape: {tensor.shape}"
)
......
......@@ -18,6 +18,7 @@ if not torch.cuda.is_available():
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
......@@ -76,7 +77,10 @@ class TestFetchEmbeddingsFromEncoder:
url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1)
encoder_cache.set(
MultimodalHasher.hash_bytes(url1.encode()),
CachedEmbedding(tensor=embedding1),
)
request: dict[str, Any] = {"messages": []}
mock_client = create_mock_encode_client([embedding2])
......@@ -98,8 +102,14 @@ class TestFetchEmbeddingsFromEncoder:
url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1)
encoder_cache.set(MultimodalHasher.hash_bytes(url2.encode()), embedding2)
encoder_cache.set(
MultimodalHasher.hash_bytes(url1.encode()),
CachedEmbedding(tensor=embedding1),
)
encoder_cache.set(
MultimodalHasher.hash_bytes(url2.encode()),
CachedEmbedding(tensor=embedding2),
)
async def should_not_call(req: dict[str, Any]) -> None:
raise AssertionError("Should not be called")
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import copy
import logging
import os
......@@ -26,21 +25,16 @@ from dynamo.runtime import Client, Component, DistributedRuntime
from ..args import Config
from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import (
MultiModalGroup,
MyRequestOutput,
PatchedTokensPrompt,
vLLMMultimodalRequest,
)
from ..multimodal_utils.model import is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import (
IMAGE_URL_KEY,
accumulate_embeddings,
fetch_embeddings_from_encode_workers,
load_embeddings,
)
from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
......@@ -96,8 +90,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
else:
self.EMBEDDINGS_DTYPE = torch.float16
self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init
......@@ -120,19 +112,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector = connect.Connector()
logger.info("Multimodal PD Worker async initialization completed.")
async def _build_request_from_frontend(
def _parse_frontend_request(
self, raw_request: dict
) -> vLLMMultimodalRequest:
"""Convert a raw frontend dict into a vLLMMultimodalRequest.
) -> tuple[vLLMMultimodalRequest, list[str]]:
"""Parse a raw frontend dict into a vLLMMultimodalRequest and image URLs.
When the PD worker is the direct frontend endpoint (no separate
processor), the Rust frontend sends a dict representation of PreprocessedRequest.
This method extracts image URLs, routes them to encode workers if available,
and assembles the standard request object that the rest of ``generate()`` expects.
The Rust frontend sends a dict with ``token_ids`` and
``multi_modal_data`` (containing image URLs). This method extracts
those fields into a structured request. No I/O is performed here;
embedding fetching is handled separately by ``_load_multimodal_data``.
"""
request_id = str(uuid.uuid4().hex)
# Extract image URLs from the raw frontend dict
image_urls: list[str] = []
mm_data = raw_request.get("multi_modal_data")
if mm_data is not None:
......@@ -140,88 +131,43 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"])
multimodal_groups: list[MultiModalGroup] = []
if self.encode_worker_client and image_urls:
multimodal_groups = await fetch_embeddings_from_encode_workers(
self.encode_worker_client,
image_urls,
request_id,
)
sampling_params = build_sampling_params(
raw_request, self.default_sampling_params
)
return vLLMMultimodalRequest(
request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(
prompt_token_ids=raw_request["token_ids"]
),
sampling_params=sampling_params,
request_id=request_id,
model=raw_request.get("model"),
multimodal_inputs=multimodal_groups,
)
# ── Request parsing ────────────────────────────────────────────────
async def _parse_request(self, request) -> vLLMMultimodalRequest:
"""Normalize any incoming format into a validated vLLMMultimodalRequest.
Handles three input shapes:
1. Raw frontend dict (has ``token_ids`` + ``multi_modal_data``)
2. JSON string (from encode worker or other serializers)
3. Plain dict (Pydantic-compatible mapping)
"""
if isinstance(request, dict) and "token_ids" in request:
return await self._build_request_from_frontend(request)
if type(request) is vLLMMultimodalRequest:
return request
if type(request) is str:
return vLLMMultimodalRequest.model_validate_json(request)
return vLLMMultimodalRequest.model_validate(request)
return request, image_urls
# ── Multimodal data loading ──────────────────────────────────────
async def _load_multimodal_data(
self, request: vLLMMultimodalRequest
) -> tuple[dict[str, Any], list[int]]:
"""Load pre-computed embeddings into an engine-ready dict.
Each ``MultiModalGroup`` carries embeddings from encode workers,
loaded via NIXL RDMA or local safetensors.
self, image_urls: list[str], request_id: str
) -> dict[str, Any]:
"""Fetch embeddings from encode workers and load into an engine-ready dict.
No-op when --route-to-encoder is not set.
Returns an empty dict when no encode worker is configured or no images
are present.
"""
multimodal_inputs: list[MultiModalGroup] = request.multimodal_inputs or []
multi_modal_data: dict[str, Any] = defaultdict(list)
task_lists = [
asyncio.create_task(
load_embeddings(
mi,
self.EMBEDDINGS_DTYPE,
self.EMBEDDINGS_DEVICE,
self.embedding_receiver,
)
)
for mi in multimodal_inputs
]
receiver_tensor_ids: list[int] = []
for task, mi in zip(task_lists, multimodal_inputs):
tensor_id, embeddings = await task
receiver_tensor_ids.append(tensor_id)
accumulate_embeddings(
multi_modal_data,
self.config.model,
self.EMBEDDINGS_DTYPE,
embeddings,
mi.image_grid_thw,
)
return multi_modal_data, receiver_tensor_ids
if not self.encode_worker_client or not image_urls:
return defaultdict(list)
return await load_multimodal_embeddings(
self.encode_worker_client, # type: ignore[arg-type]
image_urls,
request_id,
self.embedding_receiver,
model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE,
cache=self.embedding_cache_manager,
)
# ── Request metadata finalization ────────────────────────────────
......@@ -230,14 +176,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
) -> None:
"""Attach model-specific metadata and strip heavy fields from request.
"""Attach model-specific metadata to the request for the decode worker.
For Qwen VL (mRoPE) models, captures image grid dimensions and
embedding shapes so the decode worker can reconstruct
``multi_modal_data`` consistently for multiple images.
Also clears ``multimodal_inputs`` — the raw embeddings / URLs are no
longer needed once ``multi_modal_data`` is built.
"""
if is_qwen_vl_model(self.config.model) and isinstance(
multi_modal_data.get("image"), dict
......@@ -254,11 +197,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if image_embeds is not None:
request.embeddings_shape = list(image_embeds.shape)
# Use empty list instead of None to satisfy Pydantic validation
# on decode worker after vllm upgrade.
request.multimodal_inputs = []
logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
logger.debug(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))
# ── Response serialization ───────────────────────────────────────
......@@ -318,7 +257,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self,
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
):
"""Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model)
......@@ -332,9 +270,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
lora_request=lora_request,
)
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
num_output_tokens_so_far = 0
async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
......@@ -351,7 +286,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self,
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
):
"""Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request
......@@ -374,9 +308,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
lora_request=lora_request,
)
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
# Drain prefill generator (max_tokens=1, expect a single response)
async for prefill_response in gen:
pass
......@@ -415,6 +346,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f"Forwarding disaggregated decode with LoRA '{request.model}' "
f"— ensure the same adapter is loaded on the decode worker."
)
async for (
decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
......@@ -425,30 +357,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# ── Public entry point ───────────────────────────────────────────
async def generate(self, request, context):
async def generate(self, raw_request: dict, context):
"""Parse the request, load multimodal data, and run inference."""
logger.debug(f"Got raw request: {request}")
request = await self._parse_request(request)
request, image_urls = self._parse_frontend_request(raw_request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data, received_tensor_ids = await self._load_multimodal_data(
request
multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id
)
self._finalize_request_metadata(request, multi_modal_data)
logger.info(
f"Prepared multimodal data size: {len(multi_modal_data.get('image', []))}"
)
logger.debug(f"{multi_modal_data}")
if self.enable_disagg and self.decode_worker_client:
async for chunk in self._generate_disagg(
request, multi_modal_data, received_tensor_ids
):
async for chunk in self._generate_disagg(request, multi_modal_data):
yield chunk
else:
async for chunk in self._generate_agg(
request, multi_modal_data, received_tensor_ids
):
async for chunk in self._generate_agg(request, multi_modal_data):
yield chunk
......@@ -19,11 +19,7 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data,
load_vision_model,
)
from dynamo.vllm.multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
fetch_embeddings_from_encode_workers,
load_embeddings,
)
from dynamo.vllm.multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
......@@ -52,7 +48,5 @@ __all__ = [
"MultiModalRequest",
"MyRequestOutput",
"vLLMMultimodalRequest",
"accumulate_embeddings",
"fetch_embeddings_from_encode_workers",
"load_embeddings",
"load_multimodal_embeddings",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List
import torch
from vllm.sampling_params import SamplingParams as VllmSamplingParams
from dynamo.common.multimodal.embedding_transfer import AbstractEmbeddingReceiver
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver,
)
from dynamo.runtime import Client
from .encode_utils import get_embedding_hash
from .model import construct_mm_data
from .protocol import (
MultiModalGroup,
......@@ -21,39 +31,37 @@ from .protocol import (
logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
VIDEO_URL_KEY = "video_url"
SPLIT_ENCODE = int(os.getenv("DYN_SPLIT_ENCODE", 1))
# Whether to split the multimodal items into smaller batches for encoding. This can help if multimodal items can be speed up
# by separately encodeded with multiple workers.
# Need to experiment with this setting to see if it brings benefits when concurrency > encoder count.
SPLIT_ENCODE = int(os.getenv("SPLIT_ENCODE", 1))
# ── Internal helpers (all underscore-prefixed) ───────────────────────
async def load_embeddings(
mi: MultiModalGroup,
_embeddings_dtype: torch.dtype,
_embeddings_device: str,
receiver: AbstractEmbeddingReceiver,
) -> tuple[int, torch.Tensor]:
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA.
Args:
mi: A single MultiModalGroup whose ``serialized_request`` field
contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path).
receiver: AbstractEmbeddingReceiver for tensor reads.
Returns:
A tuple of (tensor_id, embeddings), where tensor_id is an integer identifier for the loaded tensor (used for later release),
and the embeddings tensor loaded into CPU memory.
class _PendingRelease:
"""Tracks NIXL tensor buffers that should be released after consumption.
For NIXL receivers, embeddings are views into pre-allocated reusable
buffers. Instead of cloning each embedding eagerly, we defer the
release until the caller has consumed the tensors (e.g. via
``_accumulate_embeddings`` which copies data through ``torch.cat``).
"""
tensor_id, embeddings = await receiver.receive_embeddings(mi.serialized_request)
return tensor_id, embeddings
__slots__ = ("_receiver", "_tensor_ids")
def __init__(self, receiver: AbstractEmbeddingReceiver):
self._receiver = receiver
self._tensor_ids: List[int] = []
def track(self, tensor_id: int) -> None:
self._tensor_ids.append(tensor_id)
def accumulate_embeddings(
def release_all(self) -> None:
for tid in self._tensor_ids:
self._receiver.release_tensor(tid)
self._tensor_ids.clear()
def _accumulate_embeddings(
multi_modal_data: Dict[str, Any],
model: str,
embeddings_dtype: torch.dtype,
......@@ -113,16 +121,32 @@ def accumulate_embeddings(
)
async def fetch_embeddings_from_encode_workers(
def _ensure_owned_tensors(multi_modal_data: Dict[str, Any]) -> None:
"""Clone tensor views so NIXL buffers can be safely released.
Only needed for single-image; multi-image goes through torch.cat
which already produces owned tensors.
"""
img = multi_modal_data.get("image")
if isinstance(img, dict):
for k, v in img.items():
if isinstance(v, torch.Tensor):
img[k] = v.clone()
elif isinstance(img, torch.Tensor):
multi_modal_data["image"] = img.clone()
async def _fetch_from_encode_workers(
encode_worker_client: Client,
image_urls: List[str],
request_id: str,
) -> List[MultiModalGroup]:
"""Fan out image URLs to encode workers and collect embedding results.
receiver: AbstractEmbeddingReceiver,
) -> tuple[List[MultiModalGroup], _PendingRelease | None]:
"""Fan out image URLs to encode workers, load embeddings, and return ready groups.
Splits image URLs into batches based on available encode worker count,
dispatches via round-robin, and collects the resulting MultiModalGroups
containing pre-computed embeddings.
For NIXL receivers the returned embeddings are zero-copy views into
pre-allocated buffers. The returned ``_PendingRelease`` must be
released after the tensors have been consumed.
"""
encode_worker_count = len(encode_worker_client.instance_ids())
if encode_worker_count == 0:
......@@ -156,7 +180,6 @@ async def fetch_embeddings_from_encode_workers(
)
batch = []
# Flush remaining
if batch:
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
......@@ -164,7 +187,6 @@ async def fetch_embeddings_from_encode_workers(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
# Collect results
multimodal_groups: List[MultiModalGroup] = []
for stream in encode_response_streams:
async for response in stream:
......@@ -173,4 +195,135 @@ async def fetch_embeddings_from_encode_workers(
if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs)
return multimodal_groups
tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
]
loaded = await asyncio.gather(*tasks)
is_local = isinstance(receiver, LocalEmbeddingReceiver)
pending: _PendingRelease | None = None if is_local else _PendingRelease(receiver)
for group, (tensor_id, embedding) in zip(multimodal_groups, loaded, strict=True):
group.loaded_embedding = embedding
if pending is not None:
pending.track(tensor_id)
return multimodal_groups, pending
async def _fetch_embeddings(
encode_worker_client: Client,
image_urls: list[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
cache: MultimodalEmbeddingCacheManager | None = None,
) -> tuple[list[MultiModalGroup], _PendingRelease | None]:
"""Fetch multimodal embeddings with transparent cache-through.
Pipeline: check_cache → fetch misses from encode workers → update_cache.
When *cache* is ``None`` the cache steps are no-ops and all URLs go
straight to the encode workers.
For NIXL receivers the returned embeddings are zero-copy views. The
returned ``_PendingRelease`` must be released after consuming the
tensors.
"""
results: list[MultiModalGroup | None] = [None] * len(image_urls)
to_fetch: list[tuple[int, str, str | None]] = []
# ── 1. Check cache (no-op when cache is None) ────────────────────
for idx, url in enumerate(image_urls):
if cache is not None:
key = get_embedding_hash(url)
cached = cache.get(key)
if cached is not None:
logger.debug(f"[{request_id}] Cache hit for URL index {idx}")
results[idx] = MultiModalGroup(
loaded_embedding=cached.tensor,
image_grid_thw=cached.image_grid_thw,
)
continue
else:
key = None
to_fetch.append((idx, url, key))
# ── 2. Fetch uncached from encode workers ────────────────────────
pending: _PendingRelease | None = None
if to_fetch:
if cache is not None:
logger.info(
f"[{request_id}] Cache miss for {len(to_fetch)}/{len(image_urls)} URLs, "
"fetching from encode workers"
)
miss_urls = [url for _, url, _ in to_fetch]
groups, pending = await _fetch_from_encode_workers(
encode_worker_client,
miss_urls,
request_id,
receiver,
)
# ── 3. Update cache (no-op when cache is None) ──────────────
for (idx, _url, key), group in zip(to_fetch, groups, strict=True):
if cache is not None and key is not None:
cache.set(
key,
CachedEmbedding(
tensor=group.loaded_embedding.clone(),
image_grid_thw=group.image_grid_thw,
),
)
results[idx] = group
else:
logger.info(f"[{request_id}] All {len(image_urls)} URLs served from cache")
return [r for r in results if r is not None], pending
# ── Public API (single entry point) ─────────────────────────────────
async def load_multimodal_embeddings(
encode_worker_client: Client,
image_urls: list[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
*,
model: str,
embeddings_dtype: torch.dtype,
cache: MultimodalEmbeddingCacheManager | None = None,
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
Full pipeline:
cache check → remote fetch → cache update → accumulate → release NIXL buffers.
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
"""
groups, pending = await _fetch_embeddings(
encode_worker_client,
image_urls,
request_id,
receiver,
cache=cache,
)
multi_modal_data: Dict[str, Any] = defaultdict(list)
for group in groups:
_accumulate_embeddings(
multi_modal_data,
model,
embeddings_dtype,
group.loaded_embedding,
group.image_grid_thw,
)
if pending is not None:
# Multi-image: torch.cat in _accumulate_embeddings already created
# owned tensors. Single-image: the data is still a view into the
# NIXL buffer, so we must clone before releasing.
if len(groups) == 1:
_ensure_owned_tensors(multi_modal_data)
pending.release_all()
return multi_modal_data
......@@ -18,6 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec
import torch
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
......@@ -171,6 +172,7 @@ class MultiModalGroup(BaseModel):
Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[TransferRequest] = None
loaded_embedding: Optional[torch.Tensor] = Field(default=None, exclude=True)
class vLLMMultimodalRequest(vLLMGenerateRequest):
......
......@@ -4,17 +4,17 @@
"""Unit tests for MultimodalPDWorkerHandler."""
import json
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.vllm.multimodal_handlers import multimodal_pd_worker_handler as mod
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
MyRequestOutput,
PatchedTokensPrompt,
vLLMMultimodalRequest,
......@@ -128,26 +128,98 @@ class TestInit:
assert handler.embedding_cache_manager._capacity_bytes == expected_bytes
class TestBuildRequestFromFrontend:
class TestParseFrontendRequest:
def test_extracts_token_ids_and_sampling_params(self):
"""Parses token_ids and sampling_params from raw frontend dict."""
handler = _make_handler()
handler.default_sampling_params = {}
raw = _make_raw_frontend_request()
request, image_urls = handler._parse_frontend_request(raw)
assert request.engine_prompt["prompt_token_ids"] == [1, 2, 3]
assert image_urls == []
def test_extracts_image_urls(self):
"""Extracts image URLs from multi_modal_data."""
handler = _make_handler()
handler.default_sampling_params = {}
raw = _make_raw_frontend_request(image_urls=["http://a.png", "http://b.png"])
request, image_urls = handler._parse_frontend_request(raw)
assert image_urls == ["http://a.png", "http://b.png"]
class TestLoadMultimodalData:
@pytest.mark.asyncio
async def test_no_encode_client_returns_empty(self):
"""Without encode client -> returns empty dict."""
handler = _make_handler(encode_worker_client=None)
mm_data = await handler._load_multimodal_data(["http://img.png"], "req-1")
assert len(mm_data) == 0
@pytest.mark.asyncio
async def test_with_encode_worker_calls_fetch(self):
"""With encode client -> delegates to fetch_embeddings_from_encode_workers."""
async def test_no_images_returns_empty(self):
"""With encode client but no images -> returns empty dict."""
handler = _make_handler(encode_worker_client=MagicMock())
mm_data = await handler._load_multimodal_data([], "req-1")
assert len(mm_data) == 0
@pytest.mark.asyncio
async def test_delegates_to_load_multimodal_embeddings(self):
"""With encode client -> delegates to load_multimodal_embeddings."""
mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client)
handler.default_sampling_params = {}
fake_group = MultiModalGroup(multimodal_input=MultiModalInput())
fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)})
with patch.object(
mod,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=fake_mm_data,
) as mock_load:
result = await handler._load_multimodal_data(["http://img.png"], "req-1")
mock_load.assert_awaited_once()
assert result is fake_mm_data
@pytest.mark.asyncio
async def test_passes_cache_to_load_multimodal_embeddings(self):
"""With cache enabled -> passes cache manager kwarg."""
mock_client = MagicMock()
config = _make_config(multimodal_embedding_cache_capacity_gb=1.0)
handler = _make_handler(config=config, encode_worker_client=mock_client)
with patch.object(
mod,
"fetch_embeddings_from_encode_workers",
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=[fake_group],
) as mock_fetch:
raw = _make_raw_frontend_request(image_urls=["http://img.png"])
result = await handler._build_request_from_frontend(raw)
return_value=defaultdict(list),
) as mock_load:
await handler._load_multimodal_data(["http://img.png"], "req-1")
mock_fetch.assert_awaited_once()
assert result.multimodal_inputs == [fake_group]
mock_load.assert_awaited_once()
assert mock_load.call_args.kwargs["cache"] is handler.embedding_cache_manager
@pytest.mark.asyncio
async def test_passes_model_and_dtype(self):
"""Model name and embeddings dtype are forwarded."""
mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client)
with patch.object(
mod,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=defaultdict(list),
) as mock_load:
await handler._load_multimodal_data(["http://img.png"], "req-1")
assert mock_load.call_args.kwargs["model"] == handler.config.model
assert (
mock_load.call_args.kwargs["embeddings_dtype"] == handler.EMBEDDINGS_DTYPE
)
class TestGenerateAgg:
......@@ -158,7 +230,6 @@ class TestGenerateAgg:
request = _make_vllm_request()
engine_resp = _make_engine_response()
# Add a proper output so we exercise the happy path
output = MagicMock()
output.token_ids = [10, 11]
output.finish_reason = "stop"
......@@ -172,7 +243,7 @@ class TestGenerateAgg:
handler.engine_client.generate = fake_generate
chunks = []
async for chunk in handler._generate_agg(request, {"image": []}, []):
async for chunk in handler._generate_agg(request, {"image": []}):
chunks.append(chunk)
assert len(chunks) == 1
......@@ -189,7 +260,6 @@ class TestGenerateDisagg:
handler = _make_handler(config=config, decode_worker_client=decode_client)
handler.engine_client = MagicMock()
# Mock prefill engine response
prefill_resp = _make_engine_response()
prefill_resp.kv_transfer_params = {"block_ids": [0, 1]}
......@@ -198,7 +268,6 @@ class TestGenerateDisagg:
handler.engine_client.generate = fake_generate
# Mock decode worker response
decode_output = MyRequestOutput(
request_id="req-1",
prompt="test",
......@@ -220,7 +289,7 @@ class TestGenerateDisagg:
request = _make_vllm_request()
chunks = []
async for chunk in handler._generate_disagg(request, {"image": []}, []):
async for chunk in handler._generate_disagg(request, {"image": []}):
chunks.append(chunk)
assert len(chunks) == 1
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for load_multimodal_embeddings in prefill_worker_utils."""
from unittest.mock import AsyncMock, patch
import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.vllm.multimodal_utils import prefill_worker_utils as mod
from dynamo.vllm.multimodal_utils.protocol import MultiModalGroup, MultiModalInput
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.multimodal,
]
MODEL = "test-model"
DTYPE = torch.float16
class TestLoadMultimodalEmbeddings:
@pytest.mark.asyncio
async def test_all_cached(self):
"""All URLs cached -> no encode worker call, returns accumulated mm_data."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(1, 10, dtype=DTYPE)
grid = [[1, 2, 3]]
url = "http://img1.png"
key = mod.get_embedding_hash(url)
cache.set(key, CachedEmbedding(tensor=tensor, image_grid_thw=grid))
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_not_awaited()
assert torch.equal(mm_data["image"], tensor)
@pytest.mark.asyncio
async def test_all_uncached_with_cache(self):
"""All URLs uncached with cache -> encode worker call, results cached."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
url = "http://img1.png"
tensor = torch.randn(1, 10, dtype=DTYPE)
fake_group = MultiModalGroup(
multimodal_input=MultiModalInput(),
image_grid_thw=[[1, 2, 3]],
loaded_embedding=tensor,
)
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_awaited_once()
assert torch.equal(mm_data["image"], tensor)
key = mod.get_embedding_hash(url)
cached = cache.get(key)
assert cached is not None
assert torch.equal(cached.tensor, tensor)
@pytest.mark.asyncio
async def test_no_cache(self):
"""Without cache -> all URLs go to encode workers."""
url = "http://img1.png"
tensor = torch.randn(1, 10, dtype=DTYPE)
fake_group = MultiModalGroup(
multimodal_input=MultiModalInput(),
loaded_embedding=tensor,
)
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=None,
)
mock_fetch.assert_awaited_once()
assert torch.equal(mm_data["image"], tensor)
@pytest.mark.asyncio
async def test_mixed_cache(self):
"""Mixed cache hits/misses -> only misses sent to encode workers."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
url_cached = "http://cached.png"
url_miss = "http://miss.png"
cached_tensor = torch.randn(1, 10, dtype=DTYPE)
miss_tensor = torch.randn(1, 10, dtype=DTYPE)
key = mod.get_embedding_hash(url_cached)
cache.set(key, CachedEmbedding(tensor=cached_tensor, image_grid_thw=None))
fake_group = MultiModalGroup(
multimodal_input=MultiModalInput(),
image_grid_thw=None,
loaded_embedding=miss_tensor,
)
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url_cached, url_miss],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_awaited_once()
call_args = mock_fetch.call_args
assert call_args[0][1] == [url_miss]
expected = torch.cat((cached_tensor, miss_tensor))
assert torch.equal(mm_data["image"], expected)
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
SINGLE_GPU=false
# Parse command line arguments
# All extra arguments are passed through to the PD worker's dynamo.vllm
# (which routes them to Dynamo or vLLM as appropriate).
EXTRA_PD_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--single-gpu)
SINGLE_GPU=true
shift
;;
-h|--help)
echo "Usage: $0 [OPTIONS] [EXTRA_ARGS...]"
echo ""
echo "Disaggregated multimodal serving with separate Encode and aggregated PD worker"
echo ""
echo "Options:"
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
echo " --single-gpu Run encode and PD workers on the same GPU (for small models, e.g. 2B)"
echo " -h, --help Show this help message"
echo ""
echo "All additional arguments are passed through to the PD worker's dynamo.vllm."
echo "Dynamo args (e.g. --multimodal-embedding-cache-capacity-gb) and"
echo "vLLM engine args (e.g. --no-enable-prefix-caching) are automatically routed."
echo ""
echo "Examples:"
echo " $0 --model llava-hf/llava-1.5-7b-hf"
echo " $0 --model microsoft/Phi-3.5-vision-instruct"
echo " $0 --model Qwen/Qwen2.5-VL-7B-Instruct"
echo " $0 --no-enable-prefix-caching --multimodal-embedding-cache-capacity-gb 2"
echo " $0 --model Qwen/Qwen2-VL-2B-Instruct --single-gpu"
echo ""
exit 0
;;
*)
EXTRA_PD_ARGS+=("$1")
shift
;;
esac
done
PD_MAX_MODEL_LEN="16384"
echo "=================================================="
echo "Disaggregated Multimodal Serving (E + PD)"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "=================================================="
# Start frontend (no router mode)
echo "Starting frontend..."
python -m dynamo.frontend &
EXTRA_ARGS=""
# Embedding transfer: 1 = local file (safetensors), 0 = NIXL RDMA
export TRANSFER_LOCAL=${TRANSFER_LOCAL:-1}
# GPU assignments (override via environment variables)
if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.4}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.4}
EXTRA_ARGS="--enforce-eager"
else
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-2}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.9}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.9}
fi
# Start encode worker
echo "Starting encode worker on GPU $DYN_ENCODE_WORKER_GPU (GPU mem: $DYN_ENCODE_GPU_MEM)..."
CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU \
python -m dynamo.vllm \
--multimodal-encode-worker \
--enable-multimodal \
--model "$MODEL_NAME" \
--gpu-memory-utilization "$DYN_ENCODE_GPU_MEM" \
$EXTRA_ARGS &
# Start PD worker (aggregated prefill+decode, routes to encoder for embeddings)
echo "Starting PD worker on GPU $DYN_PD_WORKER_GPU (GPU mem: $DYN_PD_GPU_MEM)..."
CUDA_VISIBLE_DEVICES=$DYN_PD_WORKER_GPU \
python -m dynamo.vllm \
--route-to-encoder \
--multimodal-worker \
--enable-multimodal \
--enable-mm-embeds \
--model "$MODEL_NAME" \
--max-model-len "$PD_MAX_MODEL_LEN" \
--gpu-memory-utilization "$DYN_PD_GPU_MEM" \
$EXTRA_ARGS \
"${EXTRA_PD_ARGS[@]}" &
echo "=================================================="
echo "All components started. Waiting for initialization..."
echo "=================================================="
# Wait for all background processes to complete
wait
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from io import BytesIO
import pytest
from pytest_httpserver import HTTPServer
......@@ -17,6 +18,28 @@ MULTIMODAL_IMG_PATH = os.path.join(
MULTIMODAL_IMG_URL = f"http://localhost:{IMAGE_SERVER_PORT}/llm-graphic.png"
# Git LFS pointer files start with "version "; serve a real PNG when the asset is not pulled.
def get_multimodal_test_image_bytes() -> bytes:
"""Return valid PNG bytes for /llm-graphic.png (file or minimal fallback)."""
if os.path.isfile(MULTIMODAL_IMG_PATH):
with open(MULTIMODAL_IMG_PATH, "rb") as f:
data = f.read()
if not data.startswith(b"version "):
# GitHub path
return data
# Local path where we cannot retrieve the above .png file
# Lazy import so conftest loads in environments that don't have Pillow (e.g. pre-commit).
from PIL import Image
buf = BytesIO()
# TODO: differerent models / tests may expect different colors. Need to reconcicle
# code to support all cases locally if needed.
Image.new("RGB", (2, 2), color="green").save(buf, format="PNG")
return buf.getvalue()
@pytest.fixture(scope="session")
def httpserver_listen_address():
return ("127.0.0.1", IMAGE_SERVER_PORT)
......@@ -33,15 +56,14 @@ def image_server(httpserver: HTTPServer):
Currently serves:
- /llm-graphic.png - LLM diagram image for multimodal tests
(or a minimal PNG if the file is a Git LFS pointer / not pulled)
Usage:
def test_multimodal(image_server):
url = "http://localhost:8765/llm-graphic.png"
# ... use url in your test payload
"""
# Load LLM graphic image from shared test data
with open(MULTIMODAL_IMG_PATH, "rb") as f:
image_data = f.read()
image_data = get_multimodal_test_image_bytes()
# Configure server endpoint
httpserver.expect_request("/llm-graphic.png").respond_with_data(
......
......@@ -16,7 +16,7 @@ from tests.serve.common import (
params_with_model_mark,
run_serve_deployment,
)
from tests.serve.conftest import MULTIMODAL_IMG_PATH, MULTIMODAL_IMG_URL
from tests.serve.conftest import MULTIMODAL_IMG_URL, get_multimodal_test_image_bytes
from tests.serve.lora_utils import MinioLoraConfig
from tests.utils.constants import DefaultPort
from tests.utils.engine_process import EngineConfig
......@@ -276,37 +276,34 @@ vllm_configs = {
completion_payload_default(),
],
),
# The original script is misleading agg_multimodal_epd.sh is actually a disagg
# case which uses disgg encoder. We are bringing this test back shortly
# TODO(qiwa): enable this in https://github.com/ai-dynamo/dynamo/pull/6061/
# "multimodal_agg_qwen2vl_2b_epd": VLLMConfig(
# name="multimodal_agg_qwen2vl_2b_epd",
# directory=vllm_dir,
# script_name="agg_multimodal_epd.sh",
# marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
# model="Qwen/Qwen2-VL-2B-Instruct",
# script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"],
# request_payloads=[
# chat_payload(
# [
# {
# "type": "text",
# "text": "What colors are in the following image? Respond only with the colors.",
# },
# {
# "type": "image_url",
# "image_url": {"url": MULTIMODAL_IMG_URL},
# },
# ],
# repeat_count=1,
# # With proper prompt templating, the model actually only returns "green",
# # verified behavior with native vLLM.
# expected_response=["green"],
# temperature=0.0,
# max_tokens=100,
# )
# ],
# ),
"multimodal_disagg_qwen2vl_2b_e_pd": VLLMConfig(
name="multimodal_disagg_qwen2vl_2b_e_pd",
directory=vllm_dir,
script_name="disagg_multimodal_e_pd.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen2-VL-2B-Instruct",
script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"],
request_payloads=[
chat_payload(
[
{
"type": "text",
"text": "What colors are in the following image? Respond only with the colors.",
},
{
"type": "image_url",
"image_url": {"url": MULTIMODAL_IMG_URL},
},
],
repeat_count=1,
# With proper prompt templating, the model actually only returns "green",
# verified behavior with native vLLM.
expected_response=["green"],
temperature=0.0,
max_tokens=100,
)
],
),
"multimodal_agg_frontend_decoding": VLLMConfig(
name="multimodal_agg_frontend_decoding",
directory=vllm_dir,
......@@ -755,9 +752,8 @@ def test_multimodal_b64(
This test is separate because it loads the required image at runtime
(not collection time), ensuring it only fails when actually executed.
"""
# Load B64 image at test execution time
with open(MULTIMODAL_IMG_PATH, "rb") as f:
b64_img = base64.b64encode(f.read()).decode()
# Load B64 image at test execution time (uses real PNG even if MULTIMODAL_IMG is LFS pointer)
b64_img = base64.b64encode(get_multimodal_test_image_bytes()).decode()
# Create payload with B64 image
b64_payload = chat_payload(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment