Unverified Commit a78a4260 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: use encoder cache in TRT-LLM EPD workflow (#5780)

parent b12e6710
......@@ -2,9 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from .cuda_ipc import extract_embeddings_from_handles
from .embedding_fetcher import fetch_embeddings_from_encoder
from .hasher import MultimodalHasher
__all__ = [
"MultimodalHasher",
"extract_embeddings_from_handles",
"fetch_embeddings_from_encoder",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Embedding fetcher utilities for multimodal processing with caching.
Provides utility functions for fetching image embeddings from remote encoder
with per-URL caching support.
"""
import logging
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.multimodal.async_encoder_cache import EncoderCacheManager
from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles
from dynamo.trtllm.multimodal.hasher import MultimodalHasher
logger = logging.getLogger(__name__)
async def fetch_embeddings_from_encoder(
image_urls: List[str],
request: Dict[str, Any],
encode_client: Any,
encoder_cache: Optional[EncoderCacheManager] = None,
) -> Union[List[torch.Tensor], DisaggregatedParams]:
"""
Fetch embeddings from remote encode worker.
Args:
image_urls: List of image URLs to encode (must not be empty)
request: Request dict (used for creating modified requests for caching)
encode_client: Client to call remote encode worker
encoder_cache: Optional cache for embeddings
Returns:
- List[torch.Tensor]: When using cache (CPU tensors from cache)
- DisaggregatedParams: When not using cache (contains CUDA IPC handles)
Raises:
ValueError: If image_urls is empty
"""
if not image_urls:
raise ValueError("image_urls must not be empty")
logger.info(f"fetch_embeddings_from_encoder: image_urls={image_urls}")
if encoder_cache:
# Cache path: extract embeddings to CPU tensors
return await _fetch_embeddings_with_cache(
image_urls,
request,
encoder_cache,
lambda req: _remote_encode_full_epd(
req, encode_client, update_request_for_decode=False
),
)
else:
# No cache: return DisaggregatedParams directly (no GPU→CPU extraction)
return await _remote_encode_full_epd(
request, encode_client, update_request_for_decode=True
)
async def _remote_encode_full_epd(
request: Dict[str, Any],
encode_client: Any,
update_request_for_decode: bool = True,
) -> DisaggregatedParams:
"""
Call encode worker for full EPD flow.
Args:
request: Request dict
encode_client: Client to call remote encode worker
update_request_for_decode: If True, store EPD metadata in request
Returns:
DisaggregatedParams with multimodal_embedding_handles
Raises:
RuntimeError: If encode worker returns invalid response
"""
encode_response = None
async for res in await encode_client.round_robin(request):
encode_response = res.data()
break
if not encode_response:
raise RuntimeError("Did not receive a response from the encode worker.")
if "ep_disaggregated_params" not in encode_response:
raise RuntimeError("Encode response missing ep_disaggregated_params.")
params_dict = encode_response["ep_disaggregated_params"]
if params_dict is None:
raise RuntimeError("ep_disaggregated_params is None.")
# Store EPD metadata in request for decode worker (only when not using cache)
if update_request_for_decode:
if "processed_prompt" in encode_response:
request["_epd_processed_prompt"] = encode_response["processed_prompt"]
if "prompt_token_ids" in encode_response:
request["_epd_prompt_token_ids"] = encode_response["prompt_token_ids"]
return DisaggregatedParams(**params_dict)
async def _fetch_embeddings_with_cache(
image_urls: List[str],
request: Dict[str, Any],
cache: EncoderCacheManager,
encode_fn: Callable[[Dict[str, Any]], DisaggregatedParams],
) -> List[torch.Tensor]:
"""
Encode image URLs with per-URL caching and partial cache usage.
Checks cache for each URL. Cached embeddings are reused directly.
For uncached URLs, sends a single encode request for only those URLs,
then caches the results.
Args:
image_urls: List of image URLs to encode
request: Original request dict containing the images
cache: AsyncEncoderCache instance for caching embeddings
encode_fn: Async function that encodes a request and returns ep_disaggregated_params
Should accept a modified request dict with subset of URLs
Returns:
List of embedding tensors for all images in original order
"""
if not image_urls:
raise ValueError("image_urls list is empty")
# Check cache for each URL
embeddings_with_index = [] # List of (original_index, tensor)
uncached_urls = []
uncached_indices = []
uncached_hashes = []
for i, url in enumerate(image_urls):
url_hash = MultimodalHasher.hash_bytes(url.encode())
cached = cache.get(url_hash)
if cached is not None:
logger.info(f"fetch_embeddings_with_cache: cache hit for URL: {url}")
embeddings_with_index.append((i, cached))
else:
logger.info(f"fetch_embeddings_with_cache: cache miss for URL: {url}")
uncached_urls.append(url)
uncached_indices.append(i)
uncached_hashes.append(url_hash)
# If all cached, return immediately
if not uncached_urls:
logger.info(f"fetch_embeddings_with_cache: all {len(image_urls)} URLs cached")
embeddings_with_index.sort(key=lambda x: x[0])
tensors = [t for _, t in embeddings_with_index]
return tensors
# Encode uncached URLs
logger.info(
f"fetch_embeddings_with_cache: encoding {len(uncached_urls)} uncached URLs"
)
# Create modified request with only uncached URLs
modified_request = _create_request_with_urls(request, uncached_urls)
# Call encode function
ep_disaggregated_params = await encode_fn(modified_request)
if not ep_disaggregated_params:
raise RuntimeError(
"fetch_embeddings_with_cache: Failed to get ep_disaggregated_params"
)
# Extract handles from disaggregated params
handles = getattr(ep_disaggregated_params, "multimodal_embedding_handles", None)
if not handles:
raise RuntimeError(
"fetch_embeddings_with_cache: No multimodal_embedding_handles in ep_disaggregated_params"
)
# Extract tensors from CUDA IPC handles
new_tensors = await extract_embeddings_from_handles(handles)
# Cache new tensors (reuse hashes computed during cache lookup)
for url, url_hash, tensor in zip(uncached_urls, uncached_hashes, new_tensors):
cache.set(url_hash, tensor)
logger.info(
f"fetch_embeddings_with_cache: cached embedding for URL: {url}, shape: {tensor.shape}"
)
# Add new tensors to our list with their original indices
for idx, tensor in zip(uncached_indices, new_tensors):
embeddings_with_index.append((idx, tensor))
# Sort by original order and return list
embeddings_with_index.sort(key=lambda x: x[0])
tensors = [t for _, t in embeddings_with_index]
return tensors
def _create_request_with_urls(
original_request: Dict[str, Any], image_urls: List[str]
) -> Dict[str, Any]:
"""
Create a modified request containing only specified image URLs.
Args:
original_request: Original request dict
image_urls: URLs to include in the modified request
Returns:
Modified request dict with filtered image URLs
"""
# Deep copy to avoid modifying original
import copy
modified_request = copy.deepcopy(original_request)
# Extract messages
messages = modified_request.get("extra_args", {}).get(
"messages", modified_request.get("messages", [])
)
# Filter messages to only include specified URLs
filtered_messages = []
for message in messages:
new_message = {"role": message.get("role", "user"), "content": []}
for content in message.get("content", []):
if isinstance(content, dict):
if content.get("type") == "image_url":
# Only include if URL is in our list
url = content.get("image_url", {}).get("url")
if url in image_urls:
new_message["content"].append(content)
elif content.get("type") == "text":
# Keep text content
new_message["content"].append(content)
elif isinstance(content, str):
new_message["content"].append(content)
if new_message["content"]:
filtered_messages.append(new_message)
# Update the request with filtered messages
if "extra_args" in modified_request:
modified_request["extra_args"]["messages"] = filtered_messages
else:
modified_request["messages"] = filtered_messages
return modified_request
......@@ -4,12 +4,11 @@
import logging
from typing import Optional
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.request_handlers.handler_base import (
HandlerBase,
RequestHandlerConfig,
......@@ -109,29 +108,6 @@ class PrefillHandler(HandlerBase):
super().__init__(config)
self._encoder_cache = encoder_cache
async def remote_encode_full_epd(self, request: dict):
"""
Call encode worker for full EPD flow and unpack the response.
Args:
request: Request dict
Returns:
Encoder's DisaggregatedParams to be used by the prefill worker
"""
encode_response = None
async for res in await self.encode_client.round_robin(request):
encode_response = res.data()
break
if not encode_response:
raise RuntimeError("Did not receive a response from the encode worker.")
ep_disaggregated_params = self._unpack_full_epd_response(
encode_response, request
)
return ep_disaggregated_params
async def remote_encode_with_nixl(self, request: dict):
"""
Call encode worker for NIXL flow to load embeddings and unpack the response.
......@@ -156,43 +132,6 @@ class PrefillHandler(HandlerBase):
encode_response, self.connector
)
def _unpack_full_epd_response(
self, encode_response: dict, request: dict
) -> Optional[DisaggregatedParams]:
"""
Unpack encode worker response from full EPD flow.
Extracts DisaggregatedParams and stores EPD metadata in the request
for downstream processing (multimodal_processor, decode worker).
Args:
encode_response: Response dict from encode worker
request: Request dict to store metadata in (modified in-place)
Returns:
DisaggregatedParams if present in response, None otherwise
"""
if "ep_disaggregated_params" not in encode_response:
return None
params_dict = encode_response["ep_disaggregated_params"]
if params_dict is None:
return None
# Reconstruct DisaggregatedParams object from dict
ep_disaggregated_params = DisaggregatedParams(**params_dict)
ep_disaggregated_params.request_type = "context_only"
# Store processed prompt from encoder (includes <image> tokens)
if "processed_prompt" in encode_response:
request["_epd_processed_prompt"] = encode_response["processed_prompt"]
# Store prompt_token_ids from encoder for decode worker
if "prompt_token_ids" in encode_response:
request["_epd_prompt_token_ids"] = encode_response["prompt_token_ids"]
return ep_disaggregated_params
async def generate(self, request: dict, context: Context):
"""
Prefill worker: process prompt and return disaggregated_params.
......@@ -230,7 +169,19 @@ class PrefillHandler(HandlerBase):
# Handle image URLs (full E-PD flow with MultimodalEncoder)
elif image_urls:
if self.encode_client:
ep_disaggregated_params = await self.remote_encode_full_epd(request)
logging.info(f"PrefillHandler: image_urls={image_urls}")
result = await fetch_embeddings_from_encoder(
image_urls,
request,
self.encode_client,
self._encoder_cache,
)
if isinstance(result, list):
# Cache path: got List[torch.Tensor]
embeddings_tensor = result
else:
# No-cache path: got DisaggregatedParams
ep_disaggregated_params = result
# Normal flow: Generate the prefill response locally with embeddings
response_count = 0
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for fetch_embeddings_from_encoder."""
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.multimodal.hasher import MultimodalHasher
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.gpu_0,
]
def create_mock_encode_client(
embeddings: list[torch.Tensor],
processed_prompt: str = "prompt",
prompt_token_ids: list[int] | None = None,
) -> AsyncMock:
"""Create mock encode client that returns embeddings via CUDA IPC handles."""
class MockResponse:
def data(self):
return {
"ep_disaggregated_params": {
"multimodal_embedding_handles": [
f"h{i}" for i in range(len(embeddings))
]
},
"processed_prompt": processed_prompt,
"prompt_token_ids": prompt_token_ids or [1, 2, 3],
}
async def mock_round_robin(req: dict[str, Any]) -> Any:
async def gen():
yield MockResponse()
return gen()
client = AsyncMock()
client.round_robin = mock_round_robin
return client
@pytest.fixture
def encoder_cache() -> EncoderCacheManager:
"""Create encoder cache with 10MB capacity."""
return EncoderCacheManager(capacity_bytes=10 * 1024 * 1024)
class TestFetchEmbeddingsFromEncoder:
"""Tests for fetch_embeddings_from_encoder function."""
@pytest.mark.asyncio
async def test_partial_cache_no_metadata_update(self, encoder_cache):
"""Cache path: request NOT updated with EPD metadata."""
url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1)
request: dict[str, Any] = {"messages": []}
mock_client = create_mock_encode_client([embedding2])
with patch(
"dynamo.trtllm.multimodal.embedding_fetcher.extract_embeddings_from_handles",
AsyncMock(return_value=[embedding2]),
):
result = await fetch_embeddings_from_encoder(
[url1, url2], request, mock_client, encoder_cache
)
assert len(result) == 2
assert "_epd_processed_prompt" not in request
@pytest.mark.asyncio
async def test_all_cached_no_request_sent(self, encoder_cache):
"""All cached: no encode request sent."""
url1, url2 = "http://example.com/img1.jpg", "http://example.com/img2.jpg"
embedding1, embedding2 = torch.ones(10, 256), torch.ones(10, 256) * 2
encoder_cache.set(MultimodalHasher.hash_bytes(url1.encode()), embedding1)
encoder_cache.set(MultimodalHasher.hash_bytes(url2.encode()), embedding2)
async def should_not_call(req: dict[str, Any]) -> None:
raise AssertionError("Should not be called")
mock_client = AsyncMock()
mock_client.round_robin = should_not_call
result = await fetch_embeddings_from_encoder(
[url1, url2], {"messages": []}, mock_client, encoder_cache
)
assert len(result) == 2
assert torch.equal(result[0], embedding1)
@pytest.mark.asyncio
async def test_no_cache_returns_disaggregated_params(self):
"""No cache: returns DisaggregatedParams directly, request updated with metadata."""
request: dict[str, Any] = {"messages": []}
# Pass one embedding so mock generates one handle (DisaggregatedParams requires non-empty handles)
mock_client = create_mock_encode_client(
[torch.ones(10, 256)],
processed_prompt="test <image>",
prompt_token_ids=[10, 20],
)
result = await fetch_embeddings_from_encoder(
["http://example.com/img.jpg"], request, mock_client, encoder_cache=None
)
assert isinstance(result, DisaggregatedParams)
assert request["_epd_processed_prompt"] == "test <image>"
assert request["_epd_prompt_token_ids"] == [10, 20]
@pytest.mark.asyncio
async def test_empty_urls_raises_error(self, encoder_cache):
"""Empty image_urls raises ValueError."""
mock_client = AsyncMock()
with pytest.raises(ValueError, match="image_urls must not be empty"):
await fetch_embeddings_from_encoder([], {}, mock_client, encoder_cache)
......@@ -3,15 +3,19 @@
"""Unit tests for PrefillHandler."""
from unittest.mock import MagicMock
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.trtllm.request_handlers.handlers import PrefillHandler
from dynamo.trtllm.tests.utils import create_mock_request_handler_config
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.gpu_0,
]
......@@ -29,10 +33,46 @@ def mock_encoder_cache():
cache = MagicMock()
cache.get = MagicMock(return_value=None)
cache.set = MagicMock(return_value=True)
cache.stats = {"hits": 0, "misses": 0, "entries": 0}
return cache
@pytest.fixture
def mock_context():
"""Create a mock Context."""
ctx = MagicMock()
ctx.id = MagicMock(return_value="test-id")
ctx.is_stopped = MagicMock(return_value=False)
ctx.is_killed = MagicMock(return_value=False)
return ctx
@pytest.fixture
def image_request() -> dict[str, Any]:
"""Create a request with one image URL."""
return {
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": "http://example.com/image.jpg"},
},
],
}
]
}
def setup_multimodal_config(mock_config):
"""Configure mock_config for multimodal requests."""
mock_config.multimodal_processor = MagicMock()
mock_config.multimodal_processor.extract_prompt_and_media = MagicMock(
return_value=("text", ["http://example.com/image.jpg"], [])
)
mock_config.encode_client = MagicMock()
class TestPrefillHandlerInit:
"""Tests for PrefillHandler initialization."""
......@@ -42,3 +82,63 @@ class TestPrefillHandlerInit:
assert handler.engine == mock_config.engine
assert handler._encoder_cache == mock_encoder_cache
class TestPrefillHandlerGenerate:
"""Tests for PrefillHandler.generate method."""
@pytest.mark.asyncio
async def test_embeddings_passed_to_generate_locally(
self, mock_config, mock_encoder_cache, mock_context, image_request
):
"""Test embeddings from fetch_embeddings_from_encoder passed to generate_locally."""
setup_multimodal_config(mock_config)
handler = PrefillHandler(mock_config, encoder_cache=mock_encoder_cache)
expected_embeddings = [torch.randn(10, 256)]
captured_embeddings = None
async def mock_generate_locally(request, context, embeddings, ep_params):
nonlocal captured_embeddings
captured_embeddings = embeddings
yield {"result": "mock"}
with patch(
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder",
new_callable=AsyncMock,
return_value=expected_embeddings,
) as mock_fetch:
with patch.object(handler, "generate_locally", mock_generate_locally):
async for _ in handler.generate(image_request, mock_context):
pass
mock_fetch.assert_called_once()
assert captured_embeddings is expected_embeddings
@pytest.mark.asyncio
async def test_disaggregated_params_passed_to_generate_locally(
self, mock_config, mock_context, image_request
):
"""Test DisaggregatedParams from fetch_embeddings_from_encoder passed to generate_locally."""
setup_multimodal_config(mock_config)
handler = PrefillHandler(mock_config, encoder_cache=None)
expected_params = DisaggregatedParams(request_type="context_only")
captured_ep_params = None
async def mock_generate_locally(request, context, embeddings, ep_params):
nonlocal captured_ep_params
captured_ep_params = ep_params
yield {"result": "mock"}
with patch(
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder",
new_callable=AsyncMock,
return_value=expected_params,
) as mock_fetch:
with patch.object(handler, "generate_locally", mock_generate_locally):
async for _ in handler.generate(image_request, mock_context):
pass
mock_fetch.assert_called_once()
assert captured_ep_params is expected_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment