Unverified Commit 00ea11ff authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: EC E/PD workflow in TRT-LLM (#5815)

parent 410691dc
...@@ -27,7 +27,7 @@ async def extract_embeddings_from_handles( ...@@ -27,7 +27,7 @@ async def extract_embeddings_from_handles(
properly. properly.
Args: Args:
handles: List of CUDA IPC handle dictionaries from encoder response handles: List of CUDA IPC handle dictionaries from encoder response.
Returns: Returns:
List of embedding tensors on CPU. List of embedding tensors on CPU.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
import logging
from typing import Optional
from dynamo._core import Context
from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.request_handlers.handler_base import (
HandlerBase,
RequestHandlerConfig,
)
class AggregatedHandler(HandlerBase):
    """
    Request handler that runs prefill and decode in a single worker.

    When both ``multimodal_processor`` and ``encode_client`` are truthy on the
    handler (presumably populated by ``HandlerBase`` from the config - verify
    there), image URLs found in the request are sent to the remote encoder
    first (E_PD flow) and the outcome is forwarded to local generation. The
    encoder cache is optional and is handed to the fetch helper unchanged.
    """

    def __init__(
        self,
        config: RequestHandlerConfig,
        encoder_cache: Optional[EncoderCacheManager] = None,
    ):
        """
        Args:
            config: Handler configuration forwarded to ``HandlerBase``.
            encoder_cache: Optional cache passed to
                ``fetch_embeddings_from_encoder``; ``None`` means no cache.
        """
        super().__init__(config)
        self._encoder_cache = encoder_cache

    async def generate(self, request: dict, context: Context):
        """
        Stream responses for ``request``, consulting the remote encoder first
        when the request carries image URLs.

        Args:
            request: Request payload. Messages are read from
                ``request["extra_args"]["messages"]`` when present, otherwise
                from ``request["messages"]``.
            context: Request context; its ``id()`` is logged.

        Yields:
            Every item produced by ``generate_locally``.
        """
        logging.debug(f"AggregatedHandler Request ID: {context.id()}")

        embeddings = None
        ep_disaggregated_params = None

        if self.multimodal_processor and self.encode_client:
            # Messages nested under "extra_args" win over the top-level ones.
            extra_args = request.get("extra_args", {})
            messages = extra_args.get("messages", request.get("messages", []))
            (
                _prompt,
                image_urls,
                _other_media,
            ) = self.multimodal_processor.extract_prompt_and_media(messages)

            if image_urls:
                logging.info(f"AggregatedHandler: image_urls={image_urls}")
                fetched = await fetch_embeddings_from_encoder(
                    image_urls,
                    request,
                    self.encode_client,
                    self._encoder_cache,
                )
                # A list result is forwarded as embeddings; any other result
                # is forwarded as the encoder's disaggregated params.
                embeddings, ep_disaggregated_params = (
                    (fetched, None) if isinstance(fetched, list) else (None, fetched)
                )

        local_stream = self.generate_locally(
            request, context, embeddings, ep_disaggregated_params
        )
        async for res in local_stream:
            yield res
...@@ -9,6 +9,7 @@ from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager ...@@ -9,6 +9,7 @@ from dynamo.common.memory.encoder_cache_manager import EncoderCacheManager
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder from dynamo.trtllm.multimodal.embedding_fetcher import fetch_embeddings_from_encoder
from dynamo.trtllm.request_handlers.aggregated_handler import AggregatedHandler
from dynamo.trtllm.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
HandlerBase, HandlerBase,
RequestHandlerConfig, RequestHandlerConfig,
...@@ -31,13 +32,14 @@ class RequestHandlerFactory: ...@@ -31,13 +32,14 @@ class RequestHandlerFactory:
raise ValueError( raise ValueError(
f"Invalid disaggregation_mode '{config.disaggregation_mode.value}'" f"Invalid disaggregation_mode '{config.disaggregation_mode.value}'"
) )
encoder_cache = None
if config.encoder_cache_capacity_gb > 0:
capacity_bytes = int(config.encoder_cache_capacity_gb * 1024**3)
encoder_cache = EncoderCacheManager(capacity_bytes)
if config.disaggregation_mode.value == "prefill": if config.disaggregation_mode.value == "prefill":
encoder_cache = None
if config.encoder_cache_capacity_gb > 0:
# Create encoder cache for prefill handler
capacity_bytes = int(config.encoder_cache_capacity_gb * 1024**3)
encoder_cache = EncoderCacheManager(capacity_bytes)
return PrefillHandler(config, encoder_cache=encoder_cache) return PrefillHandler(config, encoder_cache=encoder_cache)
if config.disaggregation_mode.value == "prefill_and_decode":
return AggregatedHandler(config, encoder_cache=encoder_cache)
return self.handlers[config.disaggregation_mode.value](config) return self.handlers[config.disaggregation_mode.value](config)
...@@ -45,21 +47,6 @@ def get_request_handler(config: RequestHandlerConfig) -> HandlerBase: ...@@ -45,21 +47,6 @@ def get_request_handler(config: RequestHandlerConfig) -> HandlerBase:
return RequestHandlerFactory().get_request_handler(config) return RequestHandlerFactory().get_request_handler(config)
class AggregatedHandler(HandlerBase):
"""
Handler for the aggregated mode.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
# Implement all steps locally.
async for res in self.generate_locally(request, context):
yield res
class EncodeHandler(HandlerBase): class EncodeHandler(HandlerBase):
""" """
Handler for the encode mode. Handler for the encode mode.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for AggregatedHandler."""
import pytest
import torch
from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.trtllm.request_handlers.aggregated_handler import AggregatedHandler
from dynamo.trtllm.tests.request_handlers.utils import (
create_mock_encoder_cache,
run_generate_with_mock_fetch,
setup_multimodal_config,
)
from dynamo.trtllm.tests.utils import create_mock_request_handler_config
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.gpu_0,
]
FETCH_PATCH_PATH = (
"dynamo.trtllm.request_handlers.aggregated_handler.fetch_embeddings_from_encoder"
)
class TestAggregatedHandlerGenerate:
    """Check how AggregatedHandler.generate routes the mocked encoder fetch result."""

    @pytest.mark.asyncio
    async def test_embeddings_passed_to_generate_locally(self):
        """A list of tensors from the fetch is forwarded as embeddings only."""
        handler_config = create_mock_request_handler_config(
            disaggregation_mode="prefill_and_decode"
        )
        setup_multimodal_config(handler_config, ["http://example.com/image.jpg"])
        handler = AggregatedHandler(
            handler_config, encoder_cache=create_mock_encoder_cache()
        )
        fetched_embeddings = [torch.randn(10, 256)]

        captured = await run_generate_with_mock_fetch(
            handler, FETCH_PATCH_PATH, fetched_embeddings
        )
        seen_embeddings, seen_ep_params = captured

        assert seen_embeddings is fetched_embeddings
        assert seen_ep_params is None

    @pytest.mark.asyncio
    async def test_disaggregated_params_passed_to_generate_locally(self):
        """A DisaggregatedParams result from the fetch is forwarded as ep_params only."""
        handler_config = create_mock_request_handler_config(
            disaggregation_mode="prefill_and_decode"
        )
        setup_multimodal_config(handler_config, ["http://example.com/image.jpg"])
        handler = AggregatedHandler(handler_config, encoder_cache=None)
        fetched_params = DisaggregatedParams(request_type="context_only")

        captured = await run_generate_with_mock_fetch(
            handler, FETCH_PATCH_PATH, fetched_params
        )
        seen_embeddings, seen_ep_params = captured

        assert seen_embeddings is None
        assert seen_ep_params is fetched_params
...@@ -3,14 +3,16 @@ ...@@ -3,14 +3,16 @@
"""Unit tests for PrefillHandler.""" """Unit tests for PrefillHandler."""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import torch import torch
from tensorrt_llm.llmapi import DisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams
from dynamo.trtllm.request_handlers.handlers import PrefillHandler from dynamo.trtllm.request_handlers.handlers import PrefillHandler
from dynamo.trtllm.tests.request_handlers.utils import (
create_mock_encoder_cache,
run_generate_with_mock_fetch,
setup_multimodal_config,
)
from dynamo.trtllm.tests.utils import create_mock_request_handler_config from dynamo.trtllm.tests.utils import create_mock_request_handler_config
pytestmark = [ pytestmark = [
...@@ -20,125 +22,56 @@ pytestmark = [ ...@@ -20,125 +22,56 @@ pytestmark = [
pytest.mark.gpu_0, pytest.mark.gpu_0,
] ]
FETCH_PATCH_PATH = (
@pytest.fixture "dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder"
def mock_config(): )
"""Create a mock RequestHandlerConfig."""
return create_mock_request_handler_config(disaggregation_mode="prefill")
@pytest.fixture
def mock_encoder_cache():
"""Create a mock EncoderCacheManager."""
cache = MagicMock()
cache.get = MagicMock(return_value=None)
cache.set = MagicMock(return_value=True)
return cache
@pytest.fixture
def mock_context():
"""Create a mock Context."""
ctx = MagicMock()
ctx.id = MagicMock(return_value="test-id")
ctx.is_stopped = MagicMock(return_value=False)
ctx.is_killed = MagicMock(return_value=False)
return ctx
@pytest.fixture
def image_request() -> dict[str, Any]:
"""Create a request with one image URL."""
return {
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": "http://example.com/image.jpg"},
},
],
}
]
}
def setup_multimodal_config(mock_config):
"""Configure mock_config for multimodal requests."""
mock_config.multimodal_processor = MagicMock()
mock_config.multimodal_processor.extract_prompt_and_media = MagicMock(
return_value=("text", ["http://example.com/image.jpg"], [])
)
mock_config.encode_client = MagicMock()
class TestPrefillHandlerInit: class TestPrefillHandlerInit:
"""Tests for PrefillHandler initialization.""" """Tests for PrefillHandler initialization."""
def test_init_with_encoder_cache(self, mock_config, mock_encoder_cache): def test_init_with_encoder_cache(self):
"""Test PrefillHandler can be initialized with encoder_cache.""" """Test PrefillHandler can be initialized with encoder_cache."""
handler = PrefillHandler(mock_config, encoder_cache=mock_encoder_cache) config = create_mock_request_handler_config(disaggregation_mode="prefill")
cache = create_mock_encoder_cache()
handler = PrefillHandler(config, encoder_cache=cache)
assert handler.engine == mock_config.engine assert handler.engine == config.engine
assert handler._encoder_cache == mock_encoder_cache assert handler._encoder_cache == cache
class TestPrefillHandlerGenerate: class TestPrefillHandlerGenerate:
"""Tests for PrefillHandler.generate method.""" """Tests for PrefillHandler.generate method."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_embeddings_passed_to_generate_locally( async def test_embeddings_passed_to_generate_locally(self):
self, mock_config, mock_encoder_cache, mock_context, image_request """Cache path: List[Tensor] passed as embeddings."""
): config = create_mock_request_handler_config(disaggregation_mode="prefill")
"""Test embeddings from fetch_embeddings_from_encoder passed to generate_locally.""" setup_multimodal_config(config, ["http://example.com/image.jpg"])
setup_multimodal_config(mock_config) handler = PrefillHandler(config, encoder_cache=create_mock_encoder_cache())
handler = PrefillHandler(mock_config, encoder_cache=mock_encoder_cache)
expected_embeddings = [torch.randn(10, 256)] expected_embeddings = [torch.randn(10, 256)]
captured_embeddings = None
async def mock_generate_locally(request, context, embeddings, ep_params): embeddings, ep_params = await run_generate_with_mock_fetch(
nonlocal captured_embeddings handler, FETCH_PATCH_PATH, expected_embeddings
captured_embeddings = embeddings )
yield {"result": "mock"}
with patch( assert embeddings is expected_embeddings
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder", assert ep_params is None
new_callable=AsyncMock,
return_value=expected_embeddings,
) as mock_fetch:
with patch.object(handler, "generate_locally", mock_generate_locally):
async for _ in handler.generate(image_request, mock_context):
pass
mock_fetch.assert_called_once()
assert captured_embeddings is expected_embeddings
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_disaggregated_params_passed_to_generate_locally( async def test_disaggregated_params_passed_to_generate_locally(self):
self, mock_config, mock_context, image_request """No-cache path: DisaggregatedParams passed as ep_params."""
): config = create_mock_request_handler_config(disaggregation_mode="prefill")
"""Test DisaggregatedParams from fetch_embeddings_from_encoder passed to generate_locally.""" setup_multimodal_config(config, ["http://example.com/image.jpg"])
setup_multimodal_config(mock_config) handler = PrefillHandler(config, encoder_cache=None)
handler = PrefillHandler(mock_config, encoder_cache=None)
expected_params = DisaggregatedParams(request_type="context_only") expected_params = DisaggregatedParams(request_type="context_only")
captured_ep_params = None
embeddings, ep_params = await run_generate_with_mock_fetch(
async def mock_generate_locally(request, context, embeddings, ep_params): handler, FETCH_PATCH_PATH, expected_params
nonlocal captured_ep_params )
captured_ep_params = ep_params
yield {"result": "mock"} assert embeddings is None
assert ep_params is expected_params
with patch(
"dynamo.trtllm.request_handlers.handlers.fetch_embeddings_from_encoder",
new_callable=AsyncMock,
return_value=expected_params,
) as mock_fetch:
with patch.object(handler, "generate_locally", mock_generate_locally):
async for _ in handler.generate(image_request, mock_context):
pass
mock_fetch.assert_called_once()
assert captured_ep_params is expected_params
...@@ -29,6 +29,14 @@ def mock_config(): ...@@ -29,6 +29,14 @@ def mock_config():
class TestRequestHandlerFactory: class TestRequestHandlerFactory:
"""Tests for RequestHandlerFactory.""" """Tests for RequestHandlerFactory."""
def test_invalid_mode_raises(self, mock_config):
"""Test factory raises ValueError for invalid disaggregation_mode."""
mock_config.disaggregation_mode.value = "invalid_mode"
factory = RequestHandlerFactory()
with pytest.raises(ValueError, match="Invalid disaggregation_mode"):
factory.get_request_handler(mock_config)
def test_creates_aggregated_handler(self, mock_config): def test_creates_aggregated_handler(self, mock_config):
"""Test factory creates AggregatedHandler for prefill_and_decode mode.""" """Test factory creates AggregatedHandler for prefill_and_decode mode."""
factory = RequestHandlerFactory() factory = RequestHandlerFactory()
...@@ -44,14 +52,6 @@ class TestRequestHandlerFactory: ...@@ -44,14 +52,6 @@ class TestRequestHandlerFactory:
assert isinstance(handler, PrefillHandler) assert isinstance(handler, PrefillHandler)
def test_invalid_mode_raises(self, mock_config):
"""Test factory raises ValueError for invalid disaggregation_mode."""
mock_config.disaggregation_mode.value = "invalid_mode"
factory = RequestHandlerFactory()
with pytest.raises(ValueError, match="Invalid disaggregation_mode"):
factory.get_request_handler(mock_config)
def test_prefill_handler_with_encoder_cache(self): def test_prefill_handler_with_encoder_cache(self):
"""Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0.""" """Test factory creates PrefillHandler with EncoderCacheManager when capacity > 0."""
mock_config = create_mock_request_handler_config( mock_config = create_mock_request_handler_config(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared test utilities for request handler tests."""
from typing import Any, List, Tuple
from unittest.mock import AsyncMock, MagicMock, patch
def create_mock_encoder_cache() -> MagicMock:
    """Build a stand-in EncoderCacheManager whose ``get`` returns None and ``set`` returns True."""
    # Constructor keyword arguments are set as attributes on the new mock.
    return MagicMock(
        get=MagicMock(return_value=None),
        set=MagicMock(return_value=True),
    )
def create_mock_context(request_id: str = "test-id") -> MagicMock:
    """
    Build a stand-in Context.

    Args:
        request_id: Value returned by the mock's ``id()``.

    Returns:
        A mock whose ``is_stopped()`` and ``is_killed()`` both return False.
    """
    # Constructor keyword arguments are set as attributes on the new mock.
    return MagicMock(
        id=MagicMock(return_value=request_id),
        is_stopped=MagicMock(return_value=False),
        is_killed=MagicMock(return_value=False),
    )
def setup_multimodal_config(config: MagicMock, image_urls: List[str]) -> None:
    """
    Attach a mock multimodal processor and a mock encode client to ``config``.

    Args:
        config: Config object to mutate in place.
        image_urls: URLs the mock processor reports as the request's images.
    """
    processor = MagicMock()
    # extract_prompt_and_media always answers with a fixed 3-tuple:
    # ("text", <image_urls>, []).
    processor.extract_prompt_and_media = MagicMock(
        return_value=("text", image_urls, [])
    )
    config.multimodal_processor = processor
    config.encode_client = MagicMock()
async def run_generate_with_mock_fetch(
    handler: Any,
    fetch_patch_path: str,
    mock_return_value: Any,
) -> Tuple[Any, Any]:
    """
    Drive ``handler.generate()`` with the encoder fetch and local generation stubbed.

    The fetch function at ``fetch_patch_path`` is replaced by an AsyncMock that
    returns ``mock_return_value``; ``handler.generate_locally`` is replaced by
    a recorder that remembers the embeddings and ep_params it receives. The
    helper also asserts that the mocked fetch was called exactly once.

    Args:
        handler: Handler instance (PrefillHandler or AggregatedHandler).
        fetch_patch_path: Dotted path of fetch_embeddings_from_encoder to patch.
        mock_return_value: Value the mocked fetch returns.

    Returns:
        Tuple of (captured_embeddings, captured_ep_params); both are None if
        ``generate_locally`` was never invoked.
    """
    seen: dict[str, Any] = {"embeddings": None, "ep_params": None}

    async def record_generate_locally(request, context, embeddings, ep_params):
        # Remember what generate() forwarded, then emit one dummy chunk.
        seen["embeddings"] = embeddings
        seen["ep_params"] = ep_params
        yield {"result": "mock"}

    request: dict[str, Any] = {"messages": []}
    fetch_patcher = patch(
        fetch_patch_path,
        new_callable=AsyncMock,
        return_value=mock_return_value,
    )
    locally_patcher = patch.object(
        handler, "generate_locally", record_generate_locally
    )

    with fetch_patcher as mock_fetch, locally_patcher:
        # Drain the stream; only the captured arguments matter.
        async for _ in handler.generate(request, create_mock_context()):
            pass
        mock_fetch.assert_called_once()

    return seen["embeddings"], seen["ep_params"]
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 8192
max_batch_size: 16
trust_remote_code: true
backend: pytorch
enable_chunked_prefill: true
kv_cache_config:
free_gpu_memory_fraction: 0.60
enable_block_reuse: false
cache_transceiver_config:
backend: DEFAULT
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# 1 Encode + 1 PD worker for llava-v1.6-mistral-7b-hf
# GPU 0: Encode (vision encoder)
# GPU 1: PD worker (prefill + decode, TP=1)
# Environment variables with defaults
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"llava-hf/llava-v1.6-mistral-7b-hf"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"llava-v1.6-mistral-7b-hf"}
export ENCODE_ENGINE_ARGS=${ENCODE_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/encode.yaml"}
export PD_ENGINE_ARGS=${PD_ENGINE_ARGS:-"$DYNAMO_HOME/examples/backends/trtllm/engine_configs/llava-v1.6-mistral-7b-hf/agg.yaml"}
export ENCODE_CUDA_VISIBLE_DEVICES=${ENCODE_CUDA_VISIBLE_DEVICES:-"0"}
export ENCODE_ENDPOINT=${ENCODE_ENDPOINT:-"dyn://dynamo.tensorrt_llm_encode.generate"}
export MODALITY=${MODALITY:-"multimodal"}
export ALLOWED_LOCAL_MEDIA_PATH=${ALLOWED_LOCAL_MEDIA_PATH:-"/tmp"}
export MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
export DYN_ENCODER_CACHE_CAPACITY_GB=${DYN_ENCODER_CACHE_CAPACITY_GB:-4}
export CUSTOM_TEMPLATE=${CUSTOM_TEMPLATE:-"$DYNAMO_HOME/examples/backends/trtllm/templates/llava_multimodal.jinja"}
# Setup cleanup trap
# Stop the background processes started by this script (frontend, encode
# worker, PD worker) and wait for them to exit. Registered below via `trap`
# for EXIT, INT and TERM.
cleanup() {
echo "Cleaning up background processes..."
# The PID variables are unquoted, so any that are still unset expand to
# nothing. stderr is discarded and `|| true` keeps a failed kill (e.g. the
# process is already gone) from producing a non-zero status.
kill $DYNAMO_PID $ENCODE_PID $PD_PID_1 2>/dev/null || true
# Wait for the signalled processes; failures are ignored the same way.
wait $DYNAMO_PID $ENCODE_PID $PD_PID_1 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run frontend
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python3 -m dynamo.frontend &
DYNAMO_PID=$!
# run encode worker (vision encoder on GPU 0)
CUDA_VISIBLE_DEVICES=$ENCODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$ENCODE_ENGINE_ARGS" \
--modality "$MODALITY" \
--allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH" \
--max-file-size-mb "$MAX_FILE_SIZE_MB" \
--disaggregation-mode encode &
ENCODE_PID=$!
# run PD worker 1 (prefill + decode) in the background.
# Defaults to GPU 1; set PD_CUDA_VISIBLE_DEVICES to place it elsewhere, in the
# same way ENCODE_CUDA_VISIBLE_DEVICES controls the encode worker's GPU.
CUDA_VISIBLE_DEVICES="${PD_CUDA_VISIBLE_DEVICES:-1}" python3 -m dynamo.trtllm \
    --model-path "$MODEL_PATH" \
    --served-model-name "$SERVED_MODEL_NAME" \
    --extra-engine-args "$PD_ENGINE_ARGS" \
    --modality "$MODALITY" \
    --custom-jinja-template "$CUSTOM_TEMPLATE" \
    --encode-endpoint "$ENCODE_ENDPOINT" \
    --dyn-encoder-cache-capacity-gb "$DYN_ENCODER_CACHE_CAPACITY_GB" &
# Remember the worker's PID so cleanup() can stop it.
PD_PID_1=$!
wait $DYNAMO_PID
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment