Unverified Commit c82fe888 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: add embedding cache to pd worker (#6061)

parent ebc61637
...@@ -19,13 +19,18 @@ Usage: ...@@ -19,13 +19,18 @@ Usage:
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Optional from typing import NamedTuple, Optional
import torch import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CachedEmbedding(NamedTuple):
tensor: torch.Tensor
image_grid_thw: list | None = None
class MultimodalEmbeddingCacheManager: class MultimodalEmbeddingCacheManager:
""" """
LRU cache for encoder embeddings. LRU cache for encoder embeddings.
...@@ -47,7 +52,7 @@ class MultimodalEmbeddingCacheManager: ...@@ -47,7 +52,7 @@ class MultimodalEmbeddingCacheManager:
Args: Args:
capacity_bytes: Maximum cache capacity in bytes. capacity_bytes: Maximum cache capacity in bytes.
""" """
self._cache: OrderedDict[str, torch.Tensor] = OrderedDict() self._cache: OrderedDict[str, CachedEmbedding] = OrderedDict()
self._capacity_bytes = capacity_bytes self._capacity_bytes = capacity_bytes
self._current_bytes = 0 self._current_bytes = 0
...@@ -77,9 +82,9 @@ class MultimodalEmbeddingCacheManager: ...@@ -77,9 +82,9 @@ class MultimodalEmbeddingCacheManager:
), "Tensor must be contiguous for accurate size calculation" ), "Tensor must be contiguous for accurate size calculation"
return tensor.element_size() * tensor.numel() return tensor.element_size() * tensor.numel()
def get(self, key: str) -> Optional[torch.Tensor]: def get(self, key: str) -> Optional[CachedEmbedding]:
""" """
Get a tensor from the cache. Get a cached embedding from the cache.
If found, the entry is moved to the end (most recently used). If found, the entry is moved to the end (most recently used).
...@@ -87,7 +92,7 @@ class MultimodalEmbeddingCacheManager: ...@@ -87,7 +92,7 @@ class MultimodalEmbeddingCacheManager:
key: Cache key (typically content hash). key: Cache key (typically content hash).
Returns: Returns:
The cached tensor, or None if not found. The cached embedding, or None if not found.
""" """
if key not in self._cache: if key not in self._cache:
self._misses += 1 self._misses += 1
...@@ -98,22 +103,22 @@ class MultimodalEmbeddingCacheManager: ...@@ -98,22 +103,22 @@ class MultimodalEmbeddingCacheManager:
self._hits += 1 self._hits += 1
return self._cache[key] return self._cache[key]
def set(self, key: str, tensor: torch.Tensor) -> bool: def set(self, key: str, entry: CachedEmbedding) -> bool:
""" """
Store a tensor in the cache. Store a cached embedding in the cache.
If the key already exists, the old value is replaced. If the key already exists, the old value is replaced.
If adding the tensor would exceed capacity, LRU entries are evicted. If adding the entry would exceed capacity, LRU entries are evicted.
If the tensor itself is larger than capacity, it is not stored. If the tensor itself is larger than capacity, it is not stored.
Args: Args:
key: Cache key (typically content hash). key: Cache key (typically content hash).
tensor: Tensor to cache. entry: CachedEmbedding to cache.
Returns: Returns:
True if the tensor was stored, False if it was too large. True if the entry was stored, False if it was too large.
""" """
size = self._tensor_size(tensor) size = self._tensor_size(entry.tensor)
# Don't cache if single tensor exceeds capacity # Don't cache if single tensor exceeds capacity
if size > self._capacity_bytes: if size > self._capacity_bytes:
...@@ -125,20 +130,20 @@ class MultimodalEmbeddingCacheManager: ...@@ -125,20 +130,20 @@ class MultimodalEmbeddingCacheManager:
# If key exists, remove old entry first # If key exists, remove old entry first
if key in self._cache: if key in self._cache:
old_tensor = self._cache.pop(key) old_entry = self._cache.pop(key)
self._current_bytes -= self._tensor_size(old_tensor) self._current_bytes -= self._tensor_size(old_entry.tensor)
# Evict LRU entries until we have space # Evict LRU entries until we have space
while self._current_bytes + size > self._capacity_bytes and self._cache: while self._current_bytes + size > self._capacity_bytes and self._cache:
evicted_key, evicted_tensor = self._cache.popitem(last=False) evicted_key, evicted_entry = self._cache.popitem(last=False)
evicted_size = self._tensor_size(evicted_tensor) evicted_size = self._tensor_size(evicted_entry.tensor)
self._current_bytes -= evicted_size self._current_bytes -= evicted_size
logger.debug( logger.debug(
f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB" f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB"
) )
# Store new entry # Store new entry
self._cache[key] = tensor self._cache[key] = entry
self._current_bytes += size self._current_bytes += size
logger.debug( logger.debug(
......
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import torch import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
...@@ -19,12 +20,25 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations: ...@@ -19,12 +20,25 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) # 1MB cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) # 1MB
tensor = torch.randn(100, 100) # ~40KB for float32 tensor = torch.randn(100, 100) # ~40KB for float32
result = cache.set("key1", tensor) result = cache.set("key1", CachedEmbedding(tensor))
assert result is True assert result is True
retrieved = cache.get("key1") retrieved = cache.get("key1")
assert retrieved is not None assert retrieved is not None
assert torch.equal(retrieved, tensor) assert torch.equal(retrieved.tensor, tensor)
assert retrieved.image_grid_thw is None
def test_set_and_get_with_grid(self):
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(100, 100)
grid = [[1, 2, 3]]
cache.set("key1", CachedEmbedding(tensor, grid))
retrieved = cache.get("key1")
assert retrieved is not None
assert torch.equal(retrieved.tensor, tensor)
assert retrieved.image_grid_thw == grid
def test_get_nonexistent_key(self): def test_get_nonexistent_key(self):
"""Test get returns None for nonexistent key.""" """Test get returns None for nonexistent key."""
...@@ -39,11 +53,11 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations: ...@@ -39,11 +53,11 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
tensor1 = torch.randn(10, 10) tensor1 = torch.randn(10, 10)
tensor2 = torch.randn(10, 10) tensor2 = torch.randn(10, 10)
cache.set("key1", tensor1) cache.set("key1", CachedEmbedding(tensor1))
cache.set("key1", tensor2) cache.set("key1", CachedEmbedding(tensor2))
retrieved = cache.get("key1") retrieved = cache.get("key1")
assert torch.equal(retrieved, tensor2) assert torch.equal(retrieved.tensor, tensor2)
assert cache.stats["entries"] == 1 assert cache.stats["entries"] == 1
...@@ -61,11 +75,11 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction: ...@@ -61,11 +75,11 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
t2 = torch.randn(10, 10) t2 = torch.randn(10, 10)
t3 = torch.randn(10, 10) t3 = torch.randn(10, 10)
cache.set("key1", t1) cache.set("key1", CachedEmbedding(t1))
cache.set("key2", t2) cache.set("key2", CachedEmbedding(t2))
# Adding third should evict first (LRU) # Adding third should evict first (LRU)
cache.set("key3", t3) cache.set("key3", CachedEmbedding(t3))
assert cache.get("key1") is None # Evicted assert cache.get("key1") is None # Evicted
assert cache.get("key2") is not None assert cache.get("key2") is not None
...@@ -81,14 +95,14 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction: ...@@ -81,14 +95,14 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
t2 = torch.randn(10, 10) t2 = torch.randn(10, 10)
t3 = torch.randn(10, 10) t3 = torch.randn(10, 10)
cache.set("key1", t1) cache.set("key1", CachedEmbedding(t1))
cache.set("key2", t2) cache.set("key2", CachedEmbedding(t2))
# Access key1, making key2 the LRU # Access key1, making key2 the LRU
cache.get("key1") cache.get("key1")
# Adding third should evict key2 (now LRU) # Adding third should evict key2 (now LRU)
cache.set("key3", t3) cache.set("key3", CachedEmbedding(t3))
assert cache.get("key1") is not None # Not evicted (recently accessed) assert cache.get("key1") is not None # Not evicted (recently accessed)
assert cache.get("key2") is None # Evicted (LRU) assert cache.get("key2") is None # Evicted (LRU)
...@@ -99,7 +113,7 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction: ...@@ -99,7 +113,7 @@ class TestMultimodalEmbeddingCacheManagerLRUEviction:
cache = MultimodalEmbeddingCacheManager(capacity_bytes=100) # Very small cache = MultimodalEmbeddingCacheManager(capacity_bytes=100) # Very small
tensor = torch.randn(100, 100) # ~40KB, way larger than capacity tensor = torch.randn(100, 100) # ~40KB, way larger than capacity
result = cache.set("key1", tensor) result = cache.set("key1", CachedEmbedding(tensor))
assert result is False assert result is False
assert cache.get("key1") is None assert cache.get("key1") is None
...@@ -119,10 +133,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking: ...@@ -119,10 +133,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
expected_size_1 = t1.element_size() * t1.numel() expected_size_1 = t1.element_size() * t1.numel()
expected_size_2 = t2.element_size() * t2.numel() expected_size_2 = t2.element_size() * t2.numel()
cache.set("key1", t1) cache.set("key1", CachedEmbedding(t1))
assert cache.stats["current_bytes"] == expected_size_1 assert cache.stats["current_bytes"] == expected_size_1
cache.set("key2", t2) cache.set("key2", CachedEmbedding(t2))
assert cache.stats["current_bytes"] == expected_size_1 + expected_size_2 assert cache.stats["current_bytes"] == expected_size_1 + expected_size_2
def test_size_updated_on_overwrite(self): def test_size_updated_on_overwrite(self):
...@@ -132,10 +146,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking: ...@@ -132,10 +146,10 @@ class TestMultimodalEmbeddingCacheManagerSizeTracking:
small_tensor = torch.randn(10, 10) # 400 bytes small_tensor = torch.randn(10, 10) # 400 bytes
large_tensor = torch.randn(20, 20) # 1600 bytes large_tensor = torch.randn(20, 20) # 1600 bytes
cache.set("key1", small_tensor) cache.set("key1", CachedEmbedding(small_tensor))
initial_size = cache.stats["current_bytes"] initial_size = cache.stats["current_bytes"]
cache.set("key1", large_tensor) cache.set("key1", CachedEmbedding(large_tensor))
expected_size = large_tensor.element_size() * large_tensor.numel() expected_size = large_tensor.element_size() * large_tensor.numel()
assert cache.stats["current_bytes"] == expected_size assert cache.stats["current_bytes"] == expected_size
...@@ -150,7 +164,7 @@ class TestMultimodalEmbeddingCacheManagerStats: ...@@ -150,7 +164,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
cache.set("key1", tensor) cache.set("key1", CachedEmbedding(tensor))
# Misses # Misses
cache.get("nonexistent1") cache.get("nonexistent1")
...@@ -170,7 +184,7 @@ class TestMultimodalEmbeddingCacheManagerStats: ...@@ -170,7 +184,7 @@ class TestMultimodalEmbeddingCacheManagerStats:
"""Test stats dictionary contains expected keys.""" """Test stats dictionary contains expected keys."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024) cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
cache.set("key1", tensor) cache.set("key1", CachedEmbedding(tensor))
stats = cache.stats stats = cache.stats
...@@ -190,10 +204,9 @@ class TestMultimodalEmbeddingCacheManagerStats: ...@@ -190,10 +204,9 @@ class TestMultimodalEmbeddingCacheManagerStats:
capacity = 1000 capacity = 1000
cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity) cache = MultimodalEmbeddingCacheManager(capacity_bytes=capacity)
# Create tensor of known size
# float32 = 4 bytes, so 25 elements = 100 bytes # float32 = 4 bytes, so 25 elements = 100 bytes
tensor = torch.zeros(25, dtype=torch.float32) tensor = torch.zeros(25, dtype=torch.float32)
cache.set("key1", tensor) cache.set("key1", CachedEmbedding(tensor))
stats = cache.stats stats = cache.stats
expected_utilization = 100 / capacity expected_utilization = 100 / capacity
...@@ -209,7 +222,7 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor: ...@@ -209,7 +222,7 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
assert tensor.is_contiguous() assert tensor.is_contiguous()
result = cache.set("key1", tensor) result = cache.set("key1", CachedEmbedding(tensor))
assert result is True assert result is True
def test_set_non_contiguous_tensor_raises(self): def test_set_non_contiguous_tensor_raises(self):
...@@ -221,4 +234,29 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor: ...@@ -221,4 +234,29 @@ class TestMultimodalEmbeddingCacheManagerContiguousTensor:
assert not tensor.is_contiguous() assert not tensor.is_contiguous()
with pytest.raises(AssertionError, match="Tensor must be contiguous"): with pytest.raises(AssertionError, match="Tensor must be contiguous"):
cache.set("key1", tensor) cache.set("key1", CachedEmbedding(tensor))
class TestCachedEmbeddingNamedTuple:
"""Tests for CachedEmbedding NamedTuple."""
def test_fields(self):
tensor = torch.randn(4, 4)
grid = [[1, 2, 3]]
entry = CachedEmbedding(tensor=tensor, image_grid_thw=grid)
assert torch.equal(entry.tensor, tensor)
assert entry.image_grid_thw == grid
def test_none_grid(self):
tensor = torch.randn(4, 4)
entry = CachedEmbedding(tensor=tensor, image_grid_thw=None)
assert entry.image_grid_thw is None
def test_unpacking(self):
tensor = torch.randn(4, 4)
grid = [[1, 2, 3]]
entry = CachedEmbedding(tensor=tensor, image_grid_thw=grid)
t, g = entry
assert torch.equal(t, tensor)
assert g == grid
...@@ -15,6 +15,7 @@ import torch ...@@ -15,6 +15,7 @@ import torch
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles
...@@ -148,7 +149,7 @@ async def _fetch_embeddings_with_cache( ...@@ -148,7 +149,7 @@ async def _fetch_embeddings_with_cache(
cached = cache.get(url_hash) cached = cache.get(url_hash)
if cached is not None: if cached is not None:
logger.info(f"fetch_embeddings_with_cache: cache hit for URL: {url}") logger.info(f"fetch_embeddings_with_cache: cache hit for URL: {url}")
embeddings_with_index.append((i, cached)) embeddings_with_index.append((i, cached.tensor))
else: else:
logger.info(f"fetch_embeddings_with_cache: cache miss for URL: {url}") logger.info(f"fetch_embeddings_with_cache: cache miss for URL: {url}")
uncached_urls.append(url) uncached_urls.append(url)
...@@ -189,7 +190,7 @@ async def _fetch_embeddings_with_cache( ...@@ -189,7 +190,7 @@ async def _fetch_embeddings_with_cache(
# Cache new tensors (reuse hashes computed during cache lookup) # Cache new tensors (reuse hashes computed during cache lookup)
for url, url_hash, tensor in zip(uncached_urls, uncached_hashes, new_tensors): for url, url_hash, tensor in zip(uncached_urls, uncached_hashes, new_tensors):
cache.set(url_hash, tensor) cache.set(url_hash, CachedEmbedding(tensor=tensor))
logger.info( logger.info(
f"fetch_embeddings_with_cache: cached embedding for URL: {url}, shape: {tensor.shape}" f"fetch_embeddings_with_cache: cached embedding for URL: {url}, shape: {tensor.shape}"
) )
......
...@@ -18,6 +18,7 @@ if not torch.cuda.is_available(): ...@@ -18,6 +18,7 @@ if not torch.cuda.is_available():
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
...@@ -76,7 +77,10 @@ class TestFetchEmbeddingsFromEncoder: ...@@ -76,7 +77,10 @@ class TestFetchEmbeddingsFromEncoder:
url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg" url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2 embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1) encoder_cache.set(
MultimodalHasher.hash_bytes(url1.encode()),
CachedEmbedding(tensor=embedding1),
)
request: dict[str, Any] = {"messages": []} request: dict[str, Any] = {"messages": []}
mock_client = create_mock_encode_client([embedding2]) mock_client = create_mock_encode_client([embedding2])
...@@ -98,8 +102,14 @@ class TestFetchEmbeddingsFromEncoder: ...@@ -98,8 +102,14 @@ class TestFetchEmbeddingsFromEncoder:
url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg" url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2 embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1) encoder_cache.set(
encoder_cache.set(MultimodalHasher.hash_bytes(url2.encode()), embedding2) MultimodalHasher.hash_bytes(url1.encode()),
CachedEmbedding(tensor=embedding1),
)
encoder_cache.set(
MultimodalHasher.hash_bytes(url2.encode()),
CachedEmbedding(tensor=embedding2),
)
async def should_not_call(req: dict[str, Any]) -> None: async def should_not_call(req: dict[str, Any]) -> None:
raise AssertionError("Should not be called") raise AssertionError("Should not be called")
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
import copy import copy
import logging import logging
import os import os
...@@ -26,21 +25,16 @@ from dynamo.runtime import Client, Component, DistributedRuntime ...@@ -26,21 +25,16 @@ from dynamo.runtime import Client, Component, DistributedRuntime
from ..args import Config from ..args import Config
from ..handlers import BaseWorkerHandler, build_sampling_params from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import ( from ..multimodal_utils import (
MultiModalGroup,
MyRequestOutput, MyRequestOutput,
PatchedTokensPrompt, PatchedTokensPrompt,
vLLMMultimodalRequest, vLLMMultimodalRequest,
) )
from ..multimodal_utils.model import is_qwen_vl_model from ..multimodal_utils.model import is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import ( from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
IMAGE_URL_KEY,
accumulate_embeddings,
fetch_embeddings_from_encode_workers,
load_embeddings,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1)) TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
...@@ -96,8 +90,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -96,8 +90,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
else: else:
self.EMBEDDINGS_DTYPE = torch.float16 self.EMBEDDINGS_DTYPE = torch.float16
self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently. # We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init # Note: This is synchronous initialization, async initialization happens in async_init
...@@ -120,19 +112,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -120,19 +112,18 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector = connect.Connector() self._connector = connect.Connector()
logger.info("Multimodal PD Worker async initialization completed.") logger.info("Multimodal PD Worker async initialization completed.")
async def _build_request_from_frontend( def _parse_frontend_request(
self, raw_request: dict self, raw_request: dict
) -> vLLMMultimodalRequest: ) -> tuple[vLLMMultimodalRequest, list[str]]:
"""Convert a raw frontend dict into a vLLMMultimodalRequest. """Parse a raw frontend dict into a vLLMMultimodalRequest and image URLs.
When the PD worker is the direct frontend endpoint (no separate The Rust frontend sends a dict with ``token_ids`` and
processor), the Rust frontend sends a dict representation of PreprocessedRequest. ``multi_modal_data`` (containing image URLs). This method extracts
This method extracts image URLs, routes them to encode workers if available, those fields into a structured request. No I/O is performed here;
and assembles the standard request object that the rest of ``generate()`` expects. embedding fetching is handled separately by ``_load_multimodal_data``.
""" """
request_id = str(uuid.uuid4().hex) request_id = str(uuid.uuid4().hex)
# Extract image URLs from the raw frontend dict
image_urls: list[str] = [] image_urls: list[str] = []
mm_data = raw_request.get("multi_modal_data") mm_data = raw_request.get("multi_modal_data")
if mm_data is not None: if mm_data is not None:
...@@ -140,89 +131,44 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -140,89 +131,44 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if isinstance(item, dict) and "Url" in item: if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"]) image_urls.append(item["Url"])
multimodal_groups: list[MultiModalGroup] = []
if self.encode_worker_client and image_urls:
multimodal_groups = await fetch_embeddings_from_encode_workers(
self.encode_worker_client,
image_urls,
request_id,
)
sampling_params = build_sampling_params( sampling_params = build_sampling_params(
raw_request, self.default_sampling_params raw_request, self.default_sampling_params
) )
return vLLMMultimodalRequest( request = vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt( engine_prompt=PatchedTokensPrompt(
prompt_token_ids=raw_request["token_ids"] prompt_token_ids=raw_request["token_ids"]
), ),
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
model=raw_request.get("model"), model=raw_request.get("model"),
multimodal_inputs=multimodal_groups,
) )
# ── Request parsing ──────────────────────────────────────────────── return request, image_urls
async def _parse_request(self, request) -> vLLMMultimodalRequest:
"""Normalize any incoming format into a validated vLLMMultimodalRequest.
Handles three input shapes:
1. Raw frontend dict (has ``token_ids`` + ``multi_modal_data``)
2. JSON string (from encode worker or other serializers)
3. Plain dict (Pydantic-compatible mapping)
"""
if isinstance(request, dict) and "token_ids" in request:
return await self._build_request_from_frontend(request)
if type(request) is vLLMMultimodalRequest:
return request
if type(request) is str:
return vLLMMultimodalRequest.model_validate_json(request)
return vLLMMultimodalRequest.model_validate(request)
# ── Multimodal data loading ────────────────────────────────────── # ── Multimodal data loading ──────────────────────────────────────
async def _load_multimodal_data( async def _load_multimodal_data(
self, request: vLLMMultimodalRequest self, image_urls: list[str], request_id: str
) -> tuple[dict[str, Any], list[int]]: ) -> dict[str, Any]:
"""Load pre-computed embeddings into an engine-ready dict. """Fetch embeddings from encode workers and load into an engine-ready dict.
Each ``MultiModalGroup`` carries embeddings from encode workers,
loaded via NIXL RDMA or local safetensors.
No-op when --route-to-encoder is not set. Returns an empty dict when no encode worker is configured or no images
are present.
""" """
multimodal_inputs: list[MultiModalGroup] = request.multimodal_inputs or [] if not self.encode_worker_client or not image_urls:
multi_modal_data: dict[str, Any] = defaultdict(list) return defaultdict(list)
task_lists = [ return await load_multimodal_embeddings(
asyncio.create_task( self.encode_worker_client, # type: ignore[arg-type]
load_embeddings( image_urls,
mi, request_id,
self.EMBEDDINGS_DTYPE,
self.EMBEDDINGS_DEVICE,
self.embedding_receiver, self.embedding_receiver,
) model=self.config.model,
) embeddings_dtype=self.EMBEDDINGS_DTYPE,
for mi in multimodal_inputs cache=self.embedding_cache_manager,
]
receiver_tensor_ids: list[int] = []
for task, mi in zip(task_lists, multimodal_inputs):
tensor_id, embeddings = await task
receiver_tensor_ids.append(tensor_id)
accumulate_embeddings(
multi_modal_data,
self.config.model,
self.EMBEDDINGS_DTYPE,
embeddings,
mi.image_grid_thw,
) )
return multi_modal_data, receiver_tensor_ids
# ── Request metadata finalization ──────────────────────────────── # ── Request metadata finalization ────────────────────────────────
def _finalize_request_metadata( def _finalize_request_metadata(
...@@ -230,14 +176,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -230,14 +176,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
) -> None: ) -> None:
"""Attach model-specific metadata and strip heavy fields from request. """Attach model-specific metadata to the request for the decode worker.
For Qwen VL (mRoPE) models, captures image grid dimensions and For Qwen VL (mRoPE) models, captures image grid dimensions and
embedding shapes so the decode worker can reconstruct embedding shapes so the decode worker can reconstruct
``multi_modal_data`` consistently for multiple images. ``multi_modal_data`` consistently for multiple images.
Also clears ``multimodal_inputs`` — the raw embeddings / URLs are no
longer needed once ``multi_modal_data`` is built.
""" """
if is_qwen_vl_model(self.config.model) and isinstance( if is_qwen_vl_model(self.config.model) and isinstance(
multi_modal_data.get("image"), dict multi_modal_data.get("image"), dict
...@@ -254,11 +197,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -254,11 +197,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if image_embeds is not None: if image_embeds is not None:
request.embeddings_shape = list(image_embeds.shape) request.embeddings_shape = list(image_embeds.shape)
# Use empty list instead of None to satisfy Pydantic validation logger.debug(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
# on decode worker after vllm upgrade.
request.multimodal_inputs = []
logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys())) logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))
# ── Response serialization ─────────────────────────────────────── # ── Response serialization ───────────────────────────────────────
...@@ -318,7 +257,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -318,7 +257,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self, self,
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
): ):
"""Run prefill and decode on this worker (aggregated mode).""" """Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model) lora_request = self._resolve_lora_request(request.model)
...@@ -332,9 +270,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -332,9 +270,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
lora_request=lora_request, lora_request=lora_request,
) )
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for response in gen: async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}") logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
...@@ -351,7 +286,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -351,7 +286,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self, self,
request: vLLMMultimodalRequest, request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any], multi_modal_data: dict[str, Any],
received_tensor_ids: list[int],
): ):
"""Prefill locally, then forward to a remote decode worker.""" """Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request # Prepare prefill-only request
...@@ -374,9 +308,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -374,9 +308,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
lora_request=lora_request, lora_request=lora_request,
) )
for tensor_id in received_tensor_ids:
self.embedding_receiver.release_tensor(tensor_id)
# Drain prefill generator (max_tokens=1, expect a single response) # Drain prefill generator (max_tokens=1, expect a single response)
async for prefill_response in gen: async for prefill_response in gen:
pass pass
...@@ -415,6 +346,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -415,6 +346,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f"Forwarding disaggregated decode with LoRA '{request.model}' " f"Forwarding disaggregated decode with LoRA '{request.model}' "
f"— ensure the same adapter is loaded on the decode worker." f"— ensure the same adapter is loaded on the decode worker."
) )
async for ( async for (
decode_response decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr] ) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
...@@ -425,30 +357,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -425,30 +357,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# ── Public entry point ─────────────────────────────────────────── # ── Public entry point ───────────────────────────────────────────
async def generate(self, request, context): async def generate(self, raw_request: dict, context):
"""Parse the request, load multimodal data, and run inference.""" """Parse the request, load multimodal data, and run inference."""
logger.debug(f"Got raw request: {request}") request, image_urls = self._parse_frontend_request(raw_request)
request = await self._parse_request(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data, received_tensor_ids = await self._load_multimodal_data( multi_modal_data = await self._load_multimodal_data(
request image_urls, request.request_id
) )
self._finalize_request_metadata(request, multi_modal_data) self._finalize_request_metadata(request, multi_modal_data)
logger.info(
f"Prepared multimodal data size: {len(multi_modal_data.get('image', []))}"
)
logger.debug(f"{multi_modal_data}")
if self.enable_disagg and self.decode_worker_client: if self.enable_disagg and self.decode_worker_client:
async for chunk in self._generate_disagg( async for chunk in self._generate_disagg(request, multi_modal_data):
request, multi_modal_data, received_tensor_ids
):
yield chunk yield chunk
else: else:
async for chunk in self._generate_agg( async for chunk in self._generate_agg(request, multi_modal_data):
request, multi_modal_data, received_tensor_ids
):
yield chunk yield chunk
...@@ -19,11 +19,7 @@ from dynamo.vllm.multimodal_utils.model import ( ...@@ -19,11 +19,7 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data, construct_mm_data,
load_vision_model, load_vision_model,
) )
from dynamo.vllm.multimodal_utils.prefill_worker_utils import ( from dynamo.vllm.multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
accumulate_embeddings,
fetch_embeddings_from_encode_workers,
load_embeddings,
)
from dynamo.vllm.multimodal_utils.protocol import ( from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup, MultiModalGroup,
MultiModalInput, MultiModalInput,
...@@ -52,7 +48,5 @@ __all__ = [ ...@@ -52,7 +48,5 @@ __all__ = [
"MultiModalRequest", "MultiModalRequest",
"MyRequestOutput", "MyRequestOutput",
"vLLMMultimodalRequest", "vLLMMultimodalRequest",
"accumulate_embeddings", "load_multimodal_embeddings",
"fetch_embeddings_from_encode_workers",
"load_embeddings",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
import logging import logging
import os import os
from collections import defaultdict
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
from vllm.sampling_params import SamplingParams as VllmSamplingParams from vllm.sampling_params import SamplingParams as VllmSamplingParams
from dynamo.common.multimodal.embedding_transfer import AbstractEmbeddingReceiver from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver,
)
from dynamo.runtime import Client from dynamo.runtime import Client
from .encode_utils import get_embedding_hash
from .model import construct_mm_data from .model import construct_mm_data
from .protocol import ( from .protocol import (
MultiModalGroup, MultiModalGroup,
...@@ -21,39 +31,37 @@ from .protocol import ( ...@@ -21,39 +31,37 @@ from .protocol import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url" SPLIT_ENCODE = int(os.getenv("DYN_SPLIT_ENCODE", 1))
VIDEO_URL_KEY = "video_url"
# Whether to split the multimodal items into smaller batches for encoding. This can help if multimodal items can be speed up
# by separately encodeded with multiple workers.
# Need to experiment with this setting to see if it brings benefits when concurrency > encoder count.
SPLIT_ENCODE = int(os.getenv("SPLIT_ENCODE", 1))
# ── Internal helpers (all underscore-prefixed) ───────────────────────
async def load_embeddings(
mi: MultiModalGroup, class _PendingRelease:
_embeddings_dtype: torch.dtype, """Tracks NIXL tensor buffers that should be released after consumption.
_embeddings_device: str,
receiver: AbstractEmbeddingReceiver, For NIXL receivers, embeddings are views into pre-allocated reusable
) -> tuple[int, torch.Tensor]: buffers. Instead of cloning each embedding eagerly, we defer the
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA. release until the caller has consumed the tensors (e.g. via
``_accumulate_embeddings`` which copies data through ``torch.cat``).
Args:
mi: A single MultiModalGroup whose ``serialized_request`` field
contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path).
receiver: AbstractEmbeddingReceiver for tensor reads.
Returns:
A tuple of (tensor_id, embeddings), where tensor_id is an integer identifier for the loaded tensor (used for later release),
and the embeddings tensor loaded into CPU memory.
""" """
tensor_id, embeddings = await receiver.receive_embeddings(mi.serialized_request)
return tensor_id, embeddings
__slots__ = ("_receiver", "_tensor_ids")
def __init__(self, receiver: AbstractEmbeddingReceiver):
self._receiver = receiver
self._tensor_ids: List[int] = []
def track(self, tensor_id: int) -> None:
self._tensor_ids.append(tensor_id)
def accumulate_embeddings( def release_all(self) -> None:
for tid in self._tensor_ids:
self._receiver.release_tensor(tid)
self._tensor_ids.clear()
def _accumulate_embeddings(
multi_modal_data: Dict[str, Any], multi_modal_data: Dict[str, Any],
model: str, model: str,
embeddings_dtype: torch.dtype, embeddings_dtype: torch.dtype,
...@@ -113,16 +121,32 @@ def accumulate_embeddings( ...@@ -113,16 +121,32 @@ def accumulate_embeddings(
) )
async def fetch_embeddings_from_encode_workers( def _ensure_owned_tensors(multi_modal_data: Dict[str, Any]) -> None:
"""Clone tensor views so NIXL buffers can be safely released.
Only needed for single-image; multi-image goes through torch.cat
which already produces owned tensors.
"""
img = multi_modal_data.get("image")
if isinstance(img, dict):
for k, v in img.items():
if isinstance(v, torch.Tensor):
img[k] = v.clone()
elif isinstance(img, torch.Tensor):
multi_modal_data["image"] = img.clone()
async def _fetch_from_encode_workers(
encode_worker_client: Client, encode_worker_client: Client,
image_urls: List[str], image_urls: List[str],
request_id: str, request_id: str,
) -> List[MultiModalGroup]: receiver: AbstractEmbeddingReceiver,
"""Fan out image URLs to encode workers and collect embedding results. ) -> tuple[List[MultiModalGroup], _PendingRelease | None]:
"""Fan out image URLs to encode workers, load embeddings, and return ready groups.
Splits image URLs into batches based on available encode worker count, For NIXL receivers the returned embeddings are zero-copy views into
dispatches via round-robin, and collects the resulting MultiModalGroups pre-allocated buffers. The returned ``_PendingRelease`` must be
containing pre-computed embeddings. released after the tensors have been consumed.
""" """
encode_worker_count = len(encode_worker_client.instance_ids()) encode_worker_count = len(encode_worker_client.instance_ids())
if encode_worker_count == 0: if encode_worker_count == 0:
...@@ -156,7 +180,6 @@ async def fetch_embeddings_from_encode_workers( ...@@ -156,7 +180,6 @@ async def fetch_embeddings_from_encode_workers(
) )
batch = [] batch = []
# Flush remaining
if batch: if batch:
encode_request.multimodal_inputs = batch encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json() payload = encode_request.model_dump_json()
...@@ -164,7 +187,6 @@ async def fetch_embeddings_from_encode_workers( ...@@ -164,7 +187,6 @@ async def fetch_embeddings_from_encode_workers(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type] await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
) )
# Collect results
multimodal_groups: List[MultiModalGroup] = [] multimodal_groups: List[MultiModalGroup] = []
for stream in encode_response_streams: for stream in encode_response_streams:
async for response in stream: async for response in stream:
...@@ -173,4 +195,135 @@ async def fetch_embeddings_from_encode_workers( ...@@ -173,4 +195,135 @@ async def fetch_embeddings_from_encode_workers(
if output.multimodal_inputs: if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs) multimodal_groups.extend(output.multimodal_inputs)
return multimodal_groups tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
]
loaded = await asyncio.gather(*tasks)
is_local = isinstance(receiver, LocalEmbeddingReceiver)
pending: _PendingRelease | None = None if is_local else _PendingRelease(receiver)
for group, (tensor_id, embedding) in zip(multimodal_groups, loaded, strict=True):
group.loaded_embedding = embedding
if pending is not None:
pending.track(tensor_id)
return multimodal_groups, pending
async def _fetch_embeddings(
encode_worker_client: Client,
image_urls: list[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
cache: MultimodalEmbeddingCacheManager | None = None,
) -> tuple[list[MultiModalGroup], _PendingRelease | None]:
"""Fetch multimodal embeddings with transparent cache-through.
Pipeline: check_cache → fetch misses from encode workers → update_cache.
When *cache* is ``None`` the cache steps are no-ops and all URLs go
straight to the encode workers.
For NIXL receivers the returned embeddings are zero-copy views. The
returned ``_PendingRelease`` must be released after consuming the
tensors.
"""
results: list[MultiModalGroup | None] = [None] * len(image_urls)
to_fetch: list[tuple[int, str, str | None]] = []
# ── 1. Check cache (no-op when cache is None) ────────────────────
for idx, url in enumerate(image_urls):
if cache is not None:
key = get_embedding_hash(url)
cached = cache.get(key)
if cached is not None:
logger.debug(f"[{request_id}] Cache hit for URL index {idx}")
results[idx] = MultiModalGroup(
loaded_embedding=cached.tensor,
image_grid_thw=cached.image_grid_thw,
)
continue
else:
key = None
to_fetch.append((idx, url, key))
# ── 2. Fetch uncached from encode workers ────────────────────────
pending: _PendingRelease | None = None
if to_fetch:
if cache is not None:
logger.info(
f"[{request_id}] Cache miss for {len(to_fetch)}/{len(image_urls)} URLs, "
"fetching from encode workers"
)
miss_urls = [url for _, url, _ in to_fetch]
groups, pending = await _fetch_from_encode_workers(
encode_worker_client,
miss_urls,
request_id,
receiver,
)
# ── 3. Update cache (no-op when cache is None) ──────────────
for (idx, _url, key), group in zip(to_fetch, groups, strict=True):
if cache is not None and key is not None:
cache.set(
key,
CachedEmbedding(
tensor=group.loaded_embedding.clone(),
image_grid_thw=group.image_grid_thw,
),
)
results[idx] = group
else:
logger.info(f"[{request_id}] All {len(image_urls)} URLs served from cache")
return [r for r in results if r is not None], pending
# ── Public API (single entry point) ─────────────────────────────────
async def load_multimodal_embeddings(
encode_worker_client: Client,
image_urls: list[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
*,
model: str,
embeddings_dtype: torch.dtype,
cache: MultimodalEmbeddingCacheManager | None = None,
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
Full pipeline:
cache check → remote fetch → cache update → accumulate → release NIXL buffers.
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
"""
groups, pending = await _fetch_embeddings(
encode_worker_client,
image_urls,
request_id,
receiver,
cache=cache,
)
multi_modal_data: Dict[str, Any] = defaultdict(list)
for group in groups:
_accumulate_embeddings(
multi_modal_data,
model,
embeddings_dtype,
group.loaded_embedding,
group.image_grid_thw,
)
if pending is not None:
# Multi-image: torch.cat in _accumulate_embeddings already created
# owned tensors. Single-image: the data is still a view into the
# NIXL buffer, so we must clone before releasing.
if len(groups) == 1:
_ensure_owned_tensors(multi_modal_data)
pending.release_all()
return multi_modal_data
...@@ -18,6 +18,7 @@ import json ...@@ -18,6 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec import msgspec
import torch
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from pydantic_core import core_schema from pydantic_core import core_schema
from typing_extensions import NotRequired from typing_extensions import NotRequired
...@@ -171,6 +172,7 @@ class MultiModalGroup(BaseModel): ...@@ -171,6 +172,7 @@ class MultiModalGroup(BaseModel):
Union[Tuple[int, int, int], Tuple[int, int, int, int]] Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None ] = None
serialized_request: Optional[TransferRequest] = None serialized_request: Optional[TransferRequest] = None
loaded_embedding: Optional[torch.Tensor] = Field(default=None, exclude=True)
class vLLMMultimodalRequest(vLLMGenerateRequest): class vLLMMultimodalRequest(vLLMGenerateRequest):
......
...@@ -4,17 +4,17 @@ ...@@ -4,17 +4,17 @@
"""Unit tests for MultimodalPDWorkerHandler.""" """Unit tests for MultimodalPDWorkerHandler."""
import json import json
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.vllm.multimodal_handlers import multimodal_pd_worker_handler as mod from dynamo.vllm.multimodal_handlers import multimodal_pd_worker_handler as mod
from dynamo.vllm.multimodal_utils.protocol import ( from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
MyRequestOutput, MyRequestOutput,
PatchedTokensPrompt, PatchedTokensPrompt,
vLLMMultimodalRequest, vLLMMultimodalRequest,
...@@ -128,26 +128,98 @@ class TestInit: ...@@ -128,26 +128,98 @@ class TestInit:
assert handler.embedding_cache_manager._capacity_bytes == expected_bytes assert handler.embedding_cache_manager._capacity_bytes == expected_bytes
class TestBuildRequestFromFrontend: class TestParseFrontendRequest:
def test_extracts_token_ids_and_sampling_params(self):
"""Parses token_ids and sampling_params from raw frontend dict."""
handler = _make_handler()
handler.default_sampling_params = {}
raw = _make_raw_frontend_request()
request, image_urls = handler._parse_frontend_request(raw)
assert request.engine_prompt["prompt_token_ids"] == [1, 2, 3]
assert image_urls == []
def test_extracts_image_urls(self):
"""Extracts image URLs from multi_modal_data."""
handler = _make_handler()
handler.default_sampling_params = {}
raw = _make_raw_frontend_request(image_urls=["http://a.png", "http://b.png"])
request, image_urls = handler._parse_frontend_request(raw)
assert image_urls == ["http://a.png", "http://b.png"]
class TestLoadMultimodalData:
@pytest.mark.asyncio
async def test_no_encode_client_returns_empty(self):
"""Without encode client -> returns empty dict."""
handler = _make_handler(encode_worker_client=None)
mm_data = await handler._load_multimodal_data(["http://img.png"], "req-1")
assert len(mm_data) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_encode_worker_calls_fetch(self): async def test_no_images_returns_empty(self):
"""With encode client -> delegates to fetch_embeddings_from_encode_workers.""" """With encode client but no images -> returns empty dict."""
handler = _make_handler(encode_worker_client=MagicMock())
mm_data = await handler._load_multimodal_data([], "req-1")
assert len(mm_data) == 0
@pytest.mark.asyncio
async def test_delegates_to_load_multimodal_embeddings(self):
"""With encode client -> delegates to load_multimodal_embeddings."""
mock_client = MagicMock() mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client) handler = _make_handler(encode_worker_client=mock_client)
handler.default_sampling_params = {}
fake_group = MultiModalGroup(multimodal_input=MultiModalInput()) fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)})
with patch.object(
mod,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=fake_mm_data,
) as mock_load:
result = await handler._load_multimodal_data(["http://img.png"], "req-1")
mock_load.assert_awaited_once()
assert result is fake_mm_data
@pytest.mark.asyncio
async def test_passes_cache_to_load_multimodal_embeddings(self):
"""With cache enabled -> passes cache manager kwarg."""
mock_client = MagicMock()
config = _make_config(multimodal_embedding_cache_capacity_gb=1.0)
handler = _make_handler(config=config, encode_worker_client=mock_client)
with patch.object( with patch.object(
mod, mod,
"fetch_embeddings_from_encode_workers", "load_multimodal_embeddings",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=[fake_group], return_value=defaultdict(list),
) as mock_fetch: ) as mock_load:
raw = _make_raw_frontend_request(image_urls=["http://img.png"]) await handler._load_multimodal_data(["http://img.png"], "req-1")
result = await handler._build_request_from_frontend(raw)
mock_fetch.assert_awaited_once() mock_load.assert_awaited_once()
assert result.multimodal_inputs == [fake_group] assert mock_load.call_args.kwargs["cache"] is handler.embedding_cache_manager
@pytest.mark.asyncio
async def test_passes_model_and_dtype(self):
"""Model name and embeddings dtype are forwarded."""
mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client)
with patch.object(
mod,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=defaultdict(list),
) as mock_load:
await handler._load_multimodal_data(["http://img.png"], "req-1")
assert mock_load.call_args.kwargs["model"] == handler.config.model
assert (
mock_load.call_args.kwargs["embeddings_dtype"] == handler.EMBEDDINGS_DTYPE
)
class TestGenerateAgg: class TestGenerateAgg:
...@@ -158,7 +230,6 @@ class TestGenerateAgg: ...@@ -158,7 +230,6 @@ class TestGenerateAgg:
request = _make_vllm_request() request = _make_vllm_request()
engine_resp = _make_engine_response() engine_resp = _make_engine_response()
# Add a proper output so we exercise the happy path
output = MagicMock() output = MagicMock()
output.token_ids = [10, 11] output.token_ids = [10, 11]
output.finish_reason = "stop" output.finish_reason = "stop"
...@@ -172,7 +243,7 @@ class TestGenerateAgg: ...@@ -172,7 +243,7 @@ class TestGenerateAgg:
handler.engine_client.generate = fake_generate handler.engine_client.generate = fake_generate
chunks = [] chunks = []
async for chunk in handler._generate_agg(request, {"image": []}, []): async for chunk in handler._generate_agg(request, {"image": []}):
chunks.append(chunk) chunks.append(chunk)
assert len(chunks) == 1 assert len(chunks) == 1
...@@ -189,7 +260,6 @@ class TestGenerateDisagg: ...@@ -189,7 +260,6 @@ class TestGenerateDisagg:
handler = _make_handler(config=config, decode_worker_client=decode_client) handler = _make_handler(config=config, decode_worker_client=decode_client)
handler.engine_client = MagicMock() handler.engine_client = MagicMock()
# Mock prefill engine response
prefill_resp = _make_engine_response() prefill_resp = _make_engine_response()
prefill_resp.kv_transfer_params = {"block_ids": [0, 1]} prefill_resp.kv_transfer_params = {"block_ids": [0, 1]}
...@@ -198,7 +268,6 @@ class TestGenerateDisagg: ...@@ -198,7 +268,6 @@ class TestGenerateDisagg:
handler.engine_client.generate = fake_generate handler.engine_client.generate = fake_generate
# Mock decode worker response
decode_output = MyRequestOutput( decode_output = MyRequestOutput(
request_id="req-1", request_id="req-1",
prompt="test", prompt="test",
...@@ -220,7 +289,7 @@ class TestGenerateDisagg: ...@@ -220,7 +289,7 @@ class TestGenerateDisagg:
request = _make_vllm_request() request = _make_vllm_request()
chunks = [] chunks = []
async for chunk in handler._generate_disagg(request, {"image": []}, []): async for chunk in handler._generate_disagg(request, {"image": []}):
chunks.append(chunk) chunks.append(chunk)
assert len(chunks) == 1 assert len(chunks) == 1
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for load_multimodal_embeddings in prefill_worker_utils."""
from unittest.mock import AsyncMock, patch
import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.vllm.multimodal_utils import prefill_worker_utils as mod
from dynamo.vllm.multimodal_utils.protocol import MultiModalGroup, MultiModalInput
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.multimodal,
]
MODEL = "test-model"
DTYPE = torch.float16
class TestLoadMultimodalEmbeddings:
@pytest.mark.asyncio
async def test_all_cached(self):
"""All URLs cached -> no encode worker call, returns accumulated mm_data."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
tensor = torch.randn(1, 10, dtype=DTYPE)
grid = [[1, 2, 3]]
url = "http://img1.png"
key = mod.get_embedding_hash(url)
cache.set(key, CachedEmbedding(tensor=tensor, image_grid_thw=grid))
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_not_awaited()
assert torch.equal(mm_data["image"], tensor)
@pytest.mark.asyncio
async def test_all_uncached_with_cache(self):
"""All URLs uncached with cache -> encode worker call, results cached."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
url = "http://img1.png"
tensor = torch.randn(1, 10, dtype=DTYPE)
fake_group = MultiModalGroup(
multimodal_input=MultiModalInput(),
image_grid_thw=[[1, 2, 3]],
loaded_embedding=tensor,
)
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_awaited_once()
assert torch.equal(mm_data["image"], tensor)
key = mod.get_embedding_hash(url)
cached = cache.get(key)
assert cached is not None
assert torch.equal(cached.tensor, tensor)
@pytest.mark.asyncio
async def test_no_cache(self):
"""Without cache -> all URLs go to encode workers."""
url = "http://img1.png"
tensor = torch.randn(1, 10, dtype=DTYPE)
fake_group = MultiModalGroup(
multimodal_input=MultiModalInput(),
loaded_embedding=tensor,
)
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=None,
)
mock_fetch.assert_awaited_once()
assert torch.equal(mm_data["image"], tensor)
@pytest.mark.asyncio
async def test_mixed_cache(self):
"""Mixed cache hits/misses -> only misses sent to encode workers."""
cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
url_cached = "http://cached.png"
url_miss = "http://miss.png"
cached_tensor = torch.randn(1, 10, dtype=DTYPE)
miss_tensor = torch.randn(1, 10, dtype=DTYPE)
key = mod.get_embedding_hash(url_cached)
cache.set(key, CachedEmbedding(tensor=cached_tensor, image_grid_thw=None))
fake_group = MultiModalGroup(
multimodal_input=MultiModalInput(),
image_grid_thw=None,
loaded_embedding=miss_tensor,
)
with patch.object(
mod,
"_fetch_from_encode_workers",
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
[url_cached, url_miss],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_awaited_once()
call_args = mock_fetch.call_args
assert call_args[0][1] == [url_miss]
expected = torch.cat((cached_tensor, miss_tensor))
assert torch.equal(mm_data["image"], expected)
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
SINGLE_GPU=false
# Parse command line arguments
# All extra arguments are passed through to the PD worker's dynamo.vllm
# (which routes them to Dynamo or vLLM as appropriate).
EXTRA_PD_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--single-gpu)
SINGLE_GPU=true
shift
;;
-h|--help)
echo "Usage: $0 [OPTIONS] [EXTRA_ARGS...]"
echo ""
echo "Disaggregated multimodal serving with separate Encode and aggregated PD worker"
echo ""
echo "Options:"
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
echo " --single-gpu Run encode and PD workers on the same GPU (for small models, e.g. 2B)"
echo " -h, --help Show this help message"
echo ""
echo "All additional arguments are passed through to the PD worker's dynamo.vllm."
echo "Dynamo args (e.g. --multimodal-embedding-cache-capacity-gb) and"
echo "vLLM engine args (e.g. --no-enable-prefix-caching) are automatically routed."
echo ""
echo "Examples:"
echo " $0 --model llava-hf/llava-1.5-7b-hf"
echo " $0 --model microsoft/Phi-3.5-vision-instruct"
echo " $0 --model Qwen/Qwen2.5-VL-7B-Instruct"
echo " $0 --no-enable-prefix-caching --multimodal-embedding-cache-capacity-gb 2"
echo " $0 --model Qwen/Qwen2-VL-2B-Instruct --single-gpu"
echo ""
exit 0
;;
*)
EXTRA_PD_ARGS+=("$1")
shift
;;
esac
done
PD_MAX_MODEL_LEN="16384"
echo "=================================================="
echo "Disaggregated Multimodal Serving (E + PD)"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "=================================================="
# Start frontend (no router mode)
echo "Starting frontend..."
python -m dynamo.frontend &
EXTRA_ARGS=""
# Embedding transfer: 1 = local file (safetensors), 0 = NIXL RDMA
export TRANSFER_LOCAL=${TRANSFER_LOCAL:-1}
# GPU assignments (override via environment variables)
if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.4}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.4}
EXTRA_ARGS="--enforce-eager"
else
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-2}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.9}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.9}
fi
# Start encode worker
echo "Starting encode worker on GPU $DYN_ENCODE_WORKER_GPU (GPU mem: $DYN_ENCODE_GPU_MEM)..."
CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU \
python -m dynamo.vllm \
--multimodal-encode-worker \
--enable-multimodal \
--model "$MODEL_NAME" \
--gpu-memory-utilization "$DYN_ENCODE_GPU_MEM" \
$EXTRA_ARGS &
# Start PD worker (aggregated prefill+decode, routes to encoder for embeddings)
echo "Starting PD worker on GPU $DYN_PD_WORKER_GPU (GPU mem: $DYN_PD_GPU_MEM)..."
CUDA_VISIBLE_DEVICES=$DYN_PD_WORKER_GPU \
python -m dynamo.vllm \
--route-to-encoder \
--multimodal-worker \
--enable-multimodal \
--enable-mm-embeds \
--model "$MODEL_NAME" \
--max-model-len "$PD_MAX_MODEL_LEN" \
--gpu-memory-utilization "$DYN_PD_GPU_MEM" \
$EXTRA_ARGS \
"${EXTRA_PD_ARGS[@]}" &
echo "=================================================="
echo "All components started. Waiting for initialization..."
echo "=================================================="
# Wait for all background processes to complete
wait
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
from io import BytesIO
import pytest import pytest
from pytest_httpserver import HTTPServer from pytest_httpserver import HTTPServer
...@@ -17,6 +18,28 @@ MULTIMODAL_IMG_PATH = os.path.join( ...@@ -17,6 +18,28 @@ MULTIMODAL_IMG_PATH = os.path.join(
MULTIMODAL_IMG_URL = f"http://localhost:{IMAGE_SERVER_PORT}/llm-graphic.png" MULTIMODAL_IMG_URL = f"http://localhost:{IMAGE_SERVER_PORT}/llm-graphic.png"
# Git LFS pointer files start with "version "; serve a real PNG when the asset is not pulled.
def get_multimodal_test_image_bytes() -> bytes:
"""Return valid PNG bytes for /llm-graphic.png (file or minimal fallback)."""
if os.path.isfile(MULTIMODAL_IMG_PATH):
with open(MULTIMODAL_IMG_PATH, "rb") as f:
data = f.read()
if not data.startswith(b"version "):
# GitHub path
return data
# Local path where we cannot retrieve the above .png file
# Lazy import so conftest loads in environments that don't have Pillow (e.g. pre-commit).
from PIL import Image
buf = BytesIO()
# TODO: differerent models / tests may expect different colors. Need to reconcicle
# code to support all cases locally if needed.
Image.new("RGB", (2, 2), color="green").save(buf, format="PNG")
return buf.getvalue()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def httpserver_listen_address(): def httpserver_listen_address():
return ("127.0.0.1", IMAGE_SERVER_PORT) return ("127.0.0.1", IMAGE_SERVER_PORT)
...@@ -33,15 +56,14 @@ def image_server(httpserver: HTTPServer): ...@@ -33,15 +56,14 @@ def image_server(httpserver: HTTPServer):
Currently serves: Currently serves:
- /llm-graphic.png - LLM diagram image for multimodal tests - /llm-graphic.png - LLM diagram image for multimodal tests
(or a minimal PNG if the file is a Git LFS pointer / not pulled)
Usage: Usage:
def test_multimodal(image_server): def test_multimodal(image_server):
url = "http://localhost:8765/llm-graphic.png" url = "http://localhost:8765/llm-graphic.png"
# ... use url in your test payload # ... use url in your test payload
""" """
# Load LLM graphic image from shared test data image_data = get_multimodal_test_image_bytes()
with open(MULTIMODAL_IMG_PATH, "rb") as f:
image_data = f.read()
# Configure server endpoint # Configure server endpoint
httpserver.expect_request("/llm-graphic.png").respond_with_data( httpserver.expect_request("/llm-graphic.png").respond_with_data(
......
...@@ -16,7 +16,7 @@ from tests.serve.common import ( ...@@ -16,7 +16,7 @@ from tests.serve.common import (
params_with_model_mark, params_with_model_mark,
run_serve_deployment, run_serve_deployment,
) )
from tests.serve.conftest import MULTIMODAL_IMG_PATH, MULTIMODAL_IMG_URL from tests.serve.conftest import MULTIMODAL_IMG_URL, get_multimodal_test_image_bytes
from tests.serve.lora_utils import MinioLoraConfig from tests.serve.lora_utils import MinioLoraConfig
from tests.utils.constants import DefaultPort from tests.utils.constants import DefaultPort
from tests.utils.engine_process import EngineConfig from tests.utils.engine_process import EngineConfig
...@@ -276,37 +276,34 @@ vllm_configs = { ...@@ -276,37 +276,34 @@ vllm_configs = {
completion_payload_default(), completion_payload_default(),
], ],
), ),
# The original script is misleading agg_multimodal_epd.sh is actually a disagg "multimodal_disagg_qwen2vl_2b_e_pd": VLLMConfig(
# case which uses disgg encoder. We are bringing this test back shortly name="multimodal_disagg_qwen2vl_2b_e_pd",
# TODO(qiwa): enable this in https://github.com/ai-dynamo/dynamo/pull/6061/ directory=vllm_dir,
# "multimodal_agg_qwen2vl_2b_epd": VLLMConfig( script_name="disagg_multimodal_e_pd.sh",
# name="multimodal_agg_qwen2vl_2b_epd", marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
# directory=vllm_dir, model="Qwen/Qwen2-VL-2B-Instruct",
# script_name="agg_multimodal_epd.sh", script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"],
# marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], request_payloads=[
# model="Qwen/Qwen2-VL-2B-Instruct", chat_payload(
# script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"], [
# request_payloads=[ {
# chat_payload( "type": "text",
# [ "text": "What colors are in the following image? Respond only with the colors.",
# { },
# "type": "text", {
# "text": "What colors are in the following image? Respond only with the colors.", "type": "image_url",
# }, "image_url": {"url": MULTIMODAL_IMG_URL},
# { },
# "type": "image_url", ],
# "image_url": {"url": MULTIMODAL_IMG_URL}, repeat_count=1,
# }, # With proper prompt templating, the model actually only returns "green",
# ], # verified behavior with native vLLM.
# repeat_count=1, expected_response=["green"],
# # With proper prompt templating, the model actually only returns "green", temperature=0.0,
# # verified behavior with native vLLM. max_tokens=100,
# expected_response=["green"], )
# temperature=0.0, ],
# max_tokens=100, ),
# )
# ],
# ),
"multimodal_agg_frontend_decoding": VLLMConfig( "multimodal_agg_frontend_decoding": VLLMConfig(
name="multimodal_agg_frontend_decoding", name="multimodal_agg_frontend_decoding",
directory=vllm_dir, directory=vllm_dir,
...@@ -755,9 +752,8 @@ def test_multimodal_b64( ...@@ -755,9 +752,8 @@ def test_multimodal_b64(
This test is separate because it loads the required image at runtime This test is separate because it loads the required image at runtime
(not collection time), ensuring it only fails when actually executed. (not collection time), ensuring it only fails when actually executed.
""" """
# Load B64 image at test execution time # Load B64 image at test execution time (uses real PNG even if MULTIMODAL_IMG is LFS pointer)
with open(MULTIMODAL_IMG_PATH, "rb") as f: b64_img = base64.b64encode(get_multimodal_test_image_bytes()).decode()
b64_img = base64.b64encode(f.read()).decode()
# Create payload with B64 image # Create payload with B64 image
b64_payload = chat_payload( b64_payload = chat_payload(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment