Unverified Commit b82b45a1 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: add EncoderCacheManager to TRT-LLM PrefillHandler (#5714)

parent 0268aea4
...@@ -47,9 +47,6 @@ class EncoderCacheManager: ...@@ -47,9 +47,6 @@ class EncoderCacheManager:
Args: Args:
capacity_bytes: Maximum cache capacity in bytes. capacity_bytes: Maximum cache capacity in bytes.
""" """
if capacity_bytes <= 0:
raise ValueError("capacity_bytes must be positive")
self._cache: OrderedDict[str, torch.Tensor] = OrderedDict() self._cache: OrderedDict[str, torch.Tensor] = OrderedDict()
self._capacity_bytes = capacity_bytes self._capacity_bytes = capacity_bytes
self._current_bytes = 0 self._current_bytes = 0
......
...@@ -9,27 +9,6 @@ import torch ...@@ -9,27 +9,6 @@ import torch
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
class TestEncoderCacheManagerInit:
"""Tests for initialization."""
def test_init_valid_capacity(self):
"""Test initialization with valid capacity."""
cache = EncoderCacheManager(capacity_bytes=1024)
assert cache.stats["capacity_bytes"] == 1024
assert cache.stats["current_bytes"] == 0
assert cache.stats["entries"] == 0
def test_init_invalid_capacity_zero(self):
"""Test initialization with zero capacity raises error."""
with pytest.raises(ValueError, match="capacity_bytes must be positive"):
EncoderCacheManager(capacity_bytes=0)
def test_init_invalid_capacity_negative(self):
"""Test initialization with negative capacity raises error."""
with pytest.raises(ValueError, match="capacity_bytes must be positive"):
EncoderCacheManager(capacity_bytes=-100)
class TestEncoderCacheManagerBasicOperations: class TestEncoderCacheManagerBasicOperations:
"""Tests for basic get/set operations.""" """Tests for basic get/set operations."""
......
...@@ -441,6 +441,7 @@ async def init( ...@@ -441,6 +441,7 @@ async def init(
metrics_collector=metrics_collector, metrics_collector=metrics_collector,
kv_block_size=config.kv_block_size, kv_block_size=config.kv_block_size,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
encoder_cache_capacity_gb=config.encoder_cache_capacity_gb,
) )
# Register the model with runtime config # Register the model with runtime config
......
...@@ -68,6 +68,7 @@ class RequestHandlerConfig: ...@@ -68,6 +68,7 @@ class RequestHandlerConfig:
metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector metrics_collector: Optional[Any] = None # TensorRT-LLM MetricsCollector
kv_block_size: int = 32 kv_block_size: int = 32
shutdown_event: Optional[asyncio.Event] = None shutdown_event: Optional[asyncio.Event] = None
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
class HandlerBase: class HandlerBase:
......
...@@ -7,6 +7,7 @@ from typing import Optional ...@@ -7,6 +7,7 @@ from typing import Optional
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
...@@ -31,6 +32,13 @@ class RequestHandlerFactory: ...@@ -31,6 +32,13 @@ class RequestHandlerFactory:
raise ValueError( raise ValueError(
f"Invalid disaggregation_mode '{config.disaggregation_mode.value}'" f"Invalid disaggregation_mode '{config.disaggregation_mode.value}'"
) )
if config.disaggregation_mode.value == "prefill":
encoder_cache = None
if config.encoder_cache_capacity_gb > 0:
# Create encoder cache for prefill handler
capacity_bytes = int(config.encoder_cache_capacity_gb * 1024**3)
encoder_cache = EncoderCacheManager(capacity_bytes)
return PrefillHandler(config, encoder_cache=encoder_cache)
return self.handlers[config.disaggregation_mode.value](config) return self.handlers[config.disaggregation_mode.value](config)
...@@ -93,8 +101,13 @@ class PrefillHandler(HandlerBase): ...@@ -93,8 +101,13 @@ class PrefillHandler(HandlerBase):
Handler for prefill-only workers in disaggregated serving. Handler for prefill-only workers in disaggregated serving.
""" """
def __init__(self, config: RequestHandlerConfig): def __init__(
self,
config: RequestHandlerConfig,
encoder_cache: Optional[EncoderCacheManager] = None,
):
super().__init__(config) super().__init__(config)
self._encoder_cache = encoder_cache
async def remote_encode_full_epd(self, request: dict): async def remote_encode_full_epd(self, request: dict):
""" """
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for PrefillHandler."""
from unittest.mock import MagicMock
import pytest
from dynamo.trtllm.request_handlers.handlers import PrefillHandler
from dynamo.trtllm.tests.utils import create_mock_request_handler_config
# Module-level pytest markers: pytest applies every marker in this list to
# each test collected from this file.
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.trtllm,
pytest.mark.gpu_0,
]
@pytest.fixture
def mock_config():
    """Provide a mocked RequestHandlerConfig whose disaggregation mode is "prefill"."""
    prefill_config = create_mock_request_handler_config(disaggregation_mode="prefill")
    return prefill_config
@pytest.fixture
def mock_encoder_cache():
    """Provide a MagicMock stand-in for EncoderCacheManager.

    The stand-in's ``get`` always returns ``None``, its ``set`` always
    returns ``True``, and ``stats`` is a plain dict of zeroed counters.
    """
    zeroed_stats = {"hits": 0, "misses": 0, "entries": 0}

    encoder_cache = MagicMock()
    encoder_cache.get = MagicMock(return_value=None)
    encoder_cache.set = MagicMock(return_value=True)
    encoder_cache.stats = zeroed_stats
    return encoder_cache
class TestPrefillHandlerInit:
    """Construction checks for PrefillHandler."""

    def test_init_with_encoder_cache(self, mock_config, mock_encoder_cache):
        """PrefillHandler exposes the config's engine and stores the cache it was given."""
        prefill_handler = PrefillHandler(mock_config, encoder_cache=mock_encoder_cache)

        assert prefill_handler.engine == mock_config.engine
        assert prefill_handler._encoder_cache == mock_encoder_cache
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for RequestHandlerFactory."""
import pytest
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.trtllm.request_handlers.handlers import (
AggregatedHandler,
PrefillHandler,
RequestHandlerFactory,
)
from dynamo.trtllm.tests.utils import create_mock_request_handler_config
# Module-level pytest markers: pytest applies every marker in this list to
# each test collected from this file.
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.trtllm,
pytest.mark.gpu_0,
]
@pytest.fixture
def mock_config():
    """Provide a mocked RequestHandlerConfig built with the helper's default arguments."""
    default_config = create_mock_request_handler_config()
    return default_config
class TestRequestHandlerFactory:
    """Checks which handler RequestHandlerFactory returns for a given config."""

    def test_creates_aggregated_handler(self, mock_config):
        """The fixture's unmodified config yields an AggregatedHandler."""
        created = RequestHandlerFactory().get_request_handler(mock_config)

        assert isinstance(created, AggregatedHandler)

    def test_creates_prefill_handler(self, mock_config):
        """Switching the mode to "prefill" yields a PrefillHandler."""
        mock_config.disaggregation_mode.value = "prefill"

        created = RequestHandlerFactory().get_request_handler(mock_config)

        assert isinstance(created, PrefillHandler)

    def test_invalid_mode_raises(self, mock_config):
        """An unrecognised disaggregation mode is rejected with ValueError."""
        mock_config.disaggregation_mode.value = "invalid_mode"
        handler_factory = RequestHandlerFactory()

        # Only the lookup itself is expected to raise, so the factory is
        # constructed outside the ``raises`` block.
        with pytest.raises(ValueError, match="Invalid disaggregation_mode"):
            handler_factory.get_request_handler(mock_config)

    def test_prefill_handler_with_encoder_cache(self):
        """A positive cache capacity gives the PrefillHandler an EncoderCacheManager."""
        cache_enabled_config = create_mock_request_handler_config(
            disaggregation_mode="prefill",
            encoder_cache_capacity_gb=1.0,
        )

        created = RequestHandlerFactory().get_request_handler(cache_enabled_config)

        assert isinstance(created, PrefillHandler)
        assert isinstance(created._encoder_cache, EncoderCacheManager)

    def test_prefill_handler_without_encoder_cache(self):
        """A capacity of 0 leaves the PrefillHandler without an encoder cache."""
        cache_disabled_config = create_mock_request_handler_config(
            disaggregation_mode="prefill",
            encoder_cache_capacity_gb=0,
        )

        created = RequestHandlerFactory().get_request_handler(cache_disabled_config)

        assert isinstance(created, PrefillHandler)
        assert created._encoder_cache is None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared test utilities for dynamo.trtllm tests."""
from unittest.mock import MagicMock
def create_mock_request_handler_config(
    disaggregation_mode: str = "prefill_and_decode",
    encoder_cache_capacity_gb: float = 0,
    **overrides: object,
) -> MagicMock:
    """Create a mock RequestHandlerConfig for testing.

    Args:
        disaggregation_mode: Value exposed as ``config.disaggregation_mode.value``.
        encoder_cache_capacity_gb: Encoder cache capacity in GB.
        **overrides: Optional extra attributes to set on the mock after the
            defaults below (for example ``kv_block_size=64``), so tests do
            not have to mutate the returned mock themselves.

    Returns:
        MagicMock configured as a RequestHandlerConfig.
    """
    config = MagicMock()
    config.disaggregation_mode.value = disaggregation_mode

    # Collaborators that are stubbed with their own MagicMock instances.
    config.engine = MagicMock()
    config.component = MagicMock()
    config.default_sampling_params = MagicMock()
    config.publisher = MagicMock()

    # Attributes that default to None unless overridden.
    config.metrics_collector = None
    config.encode_client = None
    config.multimodal_processor = None
    config.connector = None
    config.runtime = None
    config.shutdown_event = None

    # Real numbers rather than auto-created mocks, so numeric comparisons
    # on these attributes behave normally.
    config.kv_block_size = 32
    config.encoder_cache_capacity_gb = encoder_cache_capacity_gb

    # Caller-supplied values win over every default set above.
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)
    return config
...@@ -54,6 +54,7 @@ class Config: ...@@ -54,6 +54,7 @@ class Config:
self.modality: str = "text" self.modality: str = "text"
self.allowed_local_media_path: str = "" self.allowed_local_media_path: str = ""
self.max_file_size_mb: int = 50 self.max_file_size_mb: int = 50
self.encoder_cache_capacity_gb: float = 0
self.reasoning_parser: Optional[str] = None self.reasoning_parser: Optional[str] = None
self.tool_call_parser: Optional[str] = None self.tool_call_parser: Optional[str] = None
self.dump_config_to: Optional[str] = None self.dump_config_to: Optional[str] = None
...@@ -92,6 +93,7 @@ class Config: ...@@ -92,6 +93,7 @@ class Config:
f"modality={self.modality}, " f"modality={self.modality}, "
f"allowed_local_media_path={self.allowed_local_media_path}, " f"allowed_local_media_path={self.allowed_local_media_path}, "
f"max_file_size_mb={self.max_file_size_mb}, " f"max_file_size_mb={self.max_file_size_mb}, "
f"encoder_cache_capacity_gb={self.encoder_cache_capacity_gb}, "
f"reasoning_parser={self.reasoning_parser}, " f"reasoning_parser={self.reasoning_parser}, "
f"tool_call_parser={self.tool_call_parser}, " f"tool_call_parser={self.tool_call_parser}, "
f"dump_config_to={self.dump_config_to}, " f"dump_config_to={self.dump_config_to}, "
...@@ -286,6 +288,12 @@ def cmd_line_args(): ...@@ -286,6 +288,12 @@ def cmd_line_args():
default=50, default=50,
help="Maximum size of downloadable embedding files/Image URLs. Default: 50MB", help="Maximum size of downloadable embedding files/Image URLs. Default: 50MB",
) )
parser.add_argument(
"--dyn-encoder-cache-capacity-gb",
type=float,
default=0,
help="Capacity of the encoder cache in GB for multimodal embeddings. Default: 0",
)
# To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args # To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args
parser.add_argument( parser.add_argument(
"--dyn-tool-call-parser", "--dyn-tool-call-parser",
...@@ -384,6 +392,7 @@ def cmd_line_args(): ...@@ -384,6 +392,7 @@ def cmd_line_args():
config.encode_endpoint = args.encode_endpoint config.encode_endpoint = args.encode_endpoint
config.allowed_local_media_path = args.allowed_local_media_path config.allowed_local_media_path = args.allowed_local_media_path
config.max_file_size_mb = args.max_file_size_mb config.max_file_size_mb = args.max_file_size_mb
config.encoder_cache_capacity_gb = args.dyn_encoder_cache_capacity_gb
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
if args.pipeline_parallel_size is not None: if args.pipeline_parallel_size is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment