Unverified Commit b5c0db63 authored by Karen Chung's avatar Karen Chung Committed by GitHub
Browse files

feat: TRTLLM DP Rank Routing (#5936)

parent 04fe92d7
......@@ -90,6 +90,16 @@ class TensorRTLLMEngine:
raise RuntimeError("Engine not initialized")
return self._llm
def get_attention_dp_size(self) -> int:
    """Return the number of attention DP ranks exposed by this engine.

    When attention DP is enabled, each attention DP rank becomes a separate
    routing target, and the number of ranks equals ``tensor_parallel_size``.

    Returns:
        ``tensor_parallel_size`` when the engine args have
        ``enable_attention_dp`` set; otherwise 1. Also 1 when the engine has
        not been initialized yet, or when ``tensor_parallel_size`` is absent
        or ``None``.
    """
    if not self._llm:
        # Engine not created yet: behave as a single routing target.
        return 1

    llm_args = self.llm.args
    if not getattr(llm_args, "enable_attention_dp", False):
        return 1

    tensor_parallel_size = getattr(llm_args, "tensor_parallel_size", 1)
    # The attribute may exist but be unset (None); callers use the result in
    # range(...) and as a config value, so always hand back an int.
    return tensor_parallel_size if tensor_parallel_size is not None else 1
@staticmethod
def _prune_engine_args_for_autodeploy(engine_args) -> None:
"""Remove entries from `self.engine_args` that the autodeploy backend does not support."""
......@@ -110,7 +120,7 @@ class TensorRTLLMEngine:
"moe_cluster_parallel_size",
"moe_tensor_parallel_size",
"moe_expert_parallel_size",
"enable_attention_dp",
"enable_attention_dp", # AutoDeploy doesn't support attention DP (only pytorch backend does)
"cp_config",
]
for field_name in unsupported_fields:
......
......@@ -219,6 +219,7 @@ async def init(
"tensor_parallel_size": config.tensor_parallel_size,
"pipeline_parallel_size": config.pipeline_parallel_size,
"moe_expert_parallel_size": config.expert_parallel_size,
"enable_attention_dp": config.enable_attention_dp,
"backend": Backend.PYTORCH,
"kv_cache_config": kv_cache_config,
"gpus_per_node": gpus_per_node,
......@@ -391,11 +392,17 @@ async def init(
runtime_config.reasoning_parser = config.reasoning_parser
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.enable_local_indexer = config.enable_local_indexer
# Set data_parallel_size for attention DP mode
# This enables the router's scheduler to correctly iterate over all dp_ranks
# Need to name ADP as `data_parallel_size` for parity with other frameworks
attention_dp_size = engine.get_attention_dp_size()
runtime_config.data_parallel_size = attention_dp_size
logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}")
logging.info(
f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}"
)
logging.info(f"Set runtime config data_parallel_size: {attention_dp_size}")
# The get_engine_runtime_config function exists but is not called here due to:
# 1. get_stats_async requires active requests to work properly
......
......@@ -28,7 +28,7 @@ import traceback
import weakref
from contextlib import asynccontextmanager
from queue import Queue
from typing import Awaitable, Callable, Optional, Union
from typing import Awaitable, Callable, Dict, Optional, Union
import msgpack
import zmq
......@@ -72,6 +72,8 @@ class ZmqKvEventPublisher:
Event Format: [timestamp, [events], data_parallel_rank]
Message Format: multipart ZMQ message [topic, sequence, payload] where payload is
msgpack-serialized batch.
When attention DP is enabled for DeepSeek-style models, `data_parallel_rank` is set to the attention DP rank.
Otherwise, it defaults to 0.
Usage:
Used by Publisher class when consolidator is enabled (zmq_endpoint provided).
......@@ -94,7 +96,9 @@ class ZmqKvEventPublisher:
self.socket = self.ctx.socket(zmq.PUB)
self.socket.bind(zmq_endpoint)
self.sequence = 0
self.data_parallel_rank = 0 # TensorRT-LLM doesn't use DP for now
self.data_parallel_rank = (
0 # TensorRT-LLM doesn't use DP for now (but does support attention DP)
)
logging.info(
f"TensorRT-LLM: ZMQ KV event publisher initialized - bound to {zmq_endpoint} "
f"with topic '{topic}', kv_block_size={kv_block_size}"
......@@ -107,6 +111,7 @@ class ZmqKvEventPublisher:
block_hashes: list[int],
lora_id: int = 0,
parent_hash: Optional[int] = None,
attention_dp_rank: int = 0,
):
"""Publish a BlockStored event.
......@@ -129,9 +134,9 @@ class ZmqKvEventPublisher:
"lora_id": lora_id if lora_id != 0 else None,
}
self._publish_event(event)
self._publish_event(event, attention_dp_rank)
def publish_removed(self, block_hashes: list[int]):
def publish_removed(self, block_hashes: list[int], attention_dp_rank: int = 0):
"""Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter.
......@@ -144,22 +149,23 @@ class ZmqKvEventPublisher:
"block_hashes": block_hashes_signed,
}
self._publish_event(event)
self._publish_event(event, attention_dp_rank)
def publish_all_cleared(self, attention_dp_rank: int = 0):
    """Publish an AllBlocksCleared event.

    Args:
        attention_dp_rank: Attention DP rank to tag the event with; it is
            forwarded to ``_publish_event`` as the batch's data-parallel
            rank. Defaults to 0, so existing callers are unaffected.
    """
    event = {"type": "AllBlocksCleared"}
    # Forward the rank explicitly so a clear can target a specific attention DP rank.
    self._publish_event(event, attention_dp_rank)
def _publish_event(self, event: dict):
def _publish_event(self, event: dict, attention_dp_rank: int = 0):
"""Publish a single event to ZMQ in vLLM batch format."""
try:
# Create batch in vLLM format: [timestamp, [events], data_parallel_rank]
# The third element (data_parallel_rank) is used by the router for dp_rank routing
timestamp = time.time()
batch = [timestamp, [event], self.data_parallel_rank]
batch = [timestamp, [event], attention_dp_rank]
event_type = event.get("type", "Unknown")
logging.debug(
f"TensorRT-LLM: ZMQ publisher sending {event_type} event to {self.zmq_endpoint}"
f"TensorRT-LLM: ZMQ publisher sending {event_type} event (dp_rank={attention_dp_rank}) to {self.zmq_endpoint}"
)
# Serialize with msgpack (vLLM uses msgpack/rmp_serde compatible format)
......@@ -295,6 +301,7 @@ class Publisher:
self.max_window_size = None
self.metrics_labels = metrics_labels
self.enable_local_indexer = enable_local_indexer
self.attention_dp_size = engine.get_attention_dp_size()
# The first few kv events from the model engine are always "created" type events.
# Use these events to capture the max_window_size of the model.
......@@ -303,7 +310,9 @@ class Publisher:
# Needed by the events and metrics publishers
self.metrics_publisher = None
self.kv_event_publisher = None
self.kv_event_publishers: Optional[
Dict[int, KvEventPublisher]
] = None # One per attention_dp_rank
self.zmq_kv_event_publisher = None # ZMQ publisher for consolidator
self.publish_kv_cache_events_thread: Optional[ManagedThread] = None
self.publish_stats_thread: Optional[ManagedThread] = None
......@@ -355,15 +364,21 @@ class Publisher:
"KV Event Consolidator enabled - using ZMQ publisher only. "
"Consolidator will publish consolidated events to NATS."
)
self.kv_event_publisher = None
self.kv_event_publishers = None
else:
# No consolidator: use NATS publisher (router subscribes directly)
self.kv_event_publisher = KvEventPublisher(
self.kv_listener,
self.worker_id,
self.kv_block_size,
dp_rank=0,
enable_local_indexer=self.enable_local_indexer,
# Create one KvEventPublisher per attention_dp_rank (similar to vLLM's DP pattern)
self.kv_event_publishers = {}
for rank in range(self.attention_dp_size):
self.kv_event_publishers[rank] = KvEventPublisher(
self.kv_listener,
self.worker_id,
self.kv_block_size,
dp_rank=rank,
enable_local_indexer=self.enable_local_indexer,
)
logging.info(
f"Created {self.attention_dp_size} KV event publisher(s) for attention DP ranks"
)
# Always initialize the thread - it routes to either ZMQ or NATS publisher
......@@ -461,7 +476,7 @@ class Publisher:
return
# Check that at least one publisher is available
if self.kv_event_publisher is None and self.zmq_kv_event_publisher is None:
if not self.kv_event_publishers and self.zmq_kv_event_publisher is None:
logging.error("No KV event publisher initialized (neither NATS nor ZMQ)!")
return
......@@ -529,8 +544,12 @@ class Publisher:
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
# Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent)
# Default to 0 for backwards compatibility with older TRT-LLM versions
attention_dp_rank = event.get("attention_dp_rank", 0)
logging.debug(
f"publish stored event: engine_event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
f"publish stored event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
)
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
......@@ -542,16 +561,25 @@ class Publisher:
block_hashes,
lora_id,
parent_hash,
attention_dp_rank,
)
elif self.kv_event_publisher:
elif self.kv_event_publishers:
# No consolidator: publish to NATS (router subscribes directly)
self.kv_event_publisher.publish_stored(
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
)
# Route to correct publisher based on attention_dp_rank
publisher = self.kv_event_publishers.get(attention_dp_rank)
if publisher:
publisher.publish_stored(
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
)
else:
logging.warning(
f"No publisher for attention_dp_rank={attention_dp_rank}, "
f"available ranks: {list(self.kv_event_publishers.keys())}"
)
elif data["type"] == "removed":
self.processing_initial_created_events = False
removed_block_hashes: list[int] = []
......@@ -567,17 +595,30 @@ class Publisher:
continue
removed_block_hashes.append(block_hash)
# Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent)
attention_dp_rank = event.get("attention_dp_rank", 0)
logging.debug(
f"publish removed event: engine_event_id: {event_id}, block_hashes: {removed_block_hashes}"
f"publish removed event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, block_hashes: {removed_block_hashes}"
)
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
if self.zmq_kv_event_publisher:
# Consolidator enabled: publish to ZMQ only
self.zmq_kv_event_publisher.publish_removed(removed_block_hashes)
elif self.kv_event_publisher:
self.zmq_kv_event_publisher.publish_removed(
removed_block_hashes, attention_dp_rank
)
elif self.kv_event_publishers:
# No consolidator: publish to NATS (router subscribes directly)
self.kv_event_publisher.publish_removed(removed_block_hashes)
# Route to correct publisher based on attention_dp_rank
publisher = self.kv_event_publishers.get(attention_dp_rank)
if publisher:
publisher.publish_removed(removed_block_hashes)
else:
logging.warning(
f"No publisher for attention_dp_rank={attention_dp_rank}, "
f"available ranks: {list(self.kv_event_publishers.keys())}"
)
elif data["type"] == "created" and self.processing_initial_created_events:
self.update_max_window_size(event)
......
......@@ -26,6 +26,7 @@ from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.scheduling_params import SchedulingParams
from dynamo._core import Context
from dynamo.common.utils.otel_tracing import build_trace_headers
......@@ -680,6 +681,19 @@ class HandlerBase:
# Build trace headers for distributed tracing
trace_headers = build_trace_headers(context)
# Extract dp_rank from request's routing hints for attention DP routing
routing = request.get("routing", {})
dp_rank = routing.get("dp_rank") if routing else None
scheduling_params = None
if dp_rank is not None:
scheduling_params = SchedulingParams(
attention_dp_rank=dp_rank,
attention_dp_relax=False, # Strict routing - use the rank dynamo router selected
)
logging.debug(
f"Using dynamo router dp_rank={dp_rank} for TRTLLM attention DP scheduling"
)
try:
# NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async(
......@@ -688,6 +702,7 @@ class HandlerBase:
disaggregated_params=disaggregated_params,
streaming=streaming,
trace_headers=trace_headers,
scheduling_params=scheduling_params,
)
# Monitor for cancellation triggers and cancel by calling generation_result.abort()
......
......@@ -56,7 +56,10 @@ class TestTensorRTLLMEngine:
({"moe_cluster_parallel_size": 3}, True),
({"moe_tensor_parallel_size": 3}, True),
({"moe_expert_parallel_size": 3}, True),
({"enable_attention_dp": True}, True),
(
{"enable_attention_dp": True},
True,
), # AutoDeploy doesn't support attention DP
# Default value is an empty dict.
({"cp_config": {"foo", "bar"}}, True),
({"scheduler_config": {}}, False),
......
......@@ -38,6 +38,7 @@ class Config:
self.tensor_parallel_size: int = 1
self.pipeline_parallel_size: int = 1
self.expert_parallel_size: Optional[int] = None
self.enable_attention_dp: bool = False
self.kv_block_size: int = 32
self.migration_limit: int = 0
self.gpus_per_node: Optional[int] = None
......@@ -77,6 +78,7 @@ class Config:
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"pipeline_parallel_size={self.pipeline_parallel_size}, "
f"expert_parallel_size={self.expert_parallel_size}, "
f"enable_attention_dp={self.enable_attention_dp}, "
f"kv_block_size={self.kv_block_size}, "
f"gpus_per_node={self.gpus_per_node}, "
f"max_batch_size={self.max_batch_size}, "
......@@ -183,6 +185,11 @@ def cmd_line_args():
default=None,
help="expert parallelism size.",
)
parser.add_argument(
"--enable-attention-dp",
action="store_true",
help="Enable attention data parallelism. When enabled, attention_dp_size equals tensor_parallel_size.",
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
......@@ -399,6 +406,7 @@ def cmd_line_args():
config.pipeline_parallel_size = args.pipeline_parallel_size
if args.expert_parallel_size is not None:
config.expert_parallel_size = args.expert_parallel_size
config.enable_attention_dp = args.enable_attention_dp
if args.gpus_per_node is not None:
config.gpus_per_node = args.gpus_per_node
if args.free_gpu_memory_fraction is not None:
......
......@@ -43,6 +43,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- [Benchmarking](#benchmarking)
- [Multimodal Support](#multimodal-support)
- [Logits Processing](#logits-processing)
- [DP Rank Routing](#dp-rank-routing-attention-data-parallelism)
- [Performance Sweep](#performance-sweep)
- [Known Issues and Mitigations](#known-issues-and-mitigations)
......@@ -289,6 +290,35 @@ sampling_params.logits_processor = create_trtllm_adapters(processors)
- Processors must modify logits in-place and not return a new tensor.
- If your processor needs tokenization, ensure the tokenizer is initialized (do not skip tokenizer init).
## DP Rank Routing (Attention Data Parallelism)
TensorRT-LLM supports [attention data parallelism](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) (attention DP) for models like DeepSeek. When enabled, multiple attention DP ranks run within a single worker, each with its own KV cache. Dynamo can route requests to specific DP ranks based on KV cache state.
### Dynamo vs TRT-LLM Internal Routing
- **Dynamo DP Rank Routing**: The router selects the optimal DP rank based on KV cache overlap and instructs TRT-LLM to use that rank with strict routing (`attention_dp_relax=False`). Use this with `--router-mode kv` for cache-aware routing.
- **TRT-LLM Internal Routing**: TRT-LLM's scheduler assigns DP ranks internally. Use this with `--router-mode round-robin` or `random` when KV-aware routing isn't needed.
### Enabling DP Rank Routing
```bash
# Worker with attention DP
# (TP=2 acts as the "world size", in effect creating 2 attention DP ranks)
CUDA_VISIBLE_DEVICES=0,1 python3 -m dynamo.trtllm \
--model-path <MODEL_PATH> \
--tensor-parallel-size 2 \
--enable-attention-dp \
--publish-events-and-metrics
# Frontend with KV routing
python3 -m dynamo.frontend --router-mode kv
```
The `--enable-attention-dp` flag sets `attention_dp_size = tensor_parallel_size` and configures Dynamo to publish KV events per DP rank. The router automatically creates routing targets for each `(worker_id, dp_rank)` combination.
> [!NOTE]
> Attention DP requires TRT-LLM's PyTorch backend. AutoDeploy does not support attention DP.
## Performance Sweep
For detailed instructions on running comprehensive performance sweeps across both aggregated and disaggregated serving configurations, see the [TensorRT-LLM Benchmark Scripts for DeepSeek R1 model](../../../examples/backends/trtllm/performance_sweeps/README.md). This guide covers recommended benchmarking setups, usage of provided scripts, and best practices for evaluating system performance.
......
......@@ -95,15 +95,21 @@ class TRTLLMProcess:
- model: Model name/path (default: TinyLlama-1.1B)
- free_gpu_memory_fraction: Fraction of GPU memory to allocate (optional)
- max_seq_len: Maximum sequence length (optional)
- tensor_parallel_size: Number of GPUs for tensor parallelism (optional).
When attention DP is enabled, this is also the world size and therefore the attention_dp_size.
- enable_attention_dp: If True, enable TRT-LLM attention data parallelism.
When enabled, attention_dp_size equals tensor_parallel_size, creating
multiple routing targets within a single TRT-LLM worker process.
num_workers: Number of TRT-LLM worker processes
single_gpu: If True, all workers share GPU 0
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0).
Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs,
not multiple routing targets.
Note: TRT-LLM supports two forms of parallelism for routing:
1. Multiple workers (num_workers > 1): Each worker is a separate routing target
2. Attention DP (enable_attention_dp=True in trtllm_args): Single worker with
multiple internal attention DP ranks, each being a separate routing target
"""
# Generate unique namespace for isolation
namespace_suffix = generate_random_suffix()
......@@ -120,6 +126,8 @@ class TRTLLMProcess:
model = trtllm_args.get("model", MODEL_NAME)
free_gpu_memory_fraction = trtllm_args.get("free_gpu_memory_fraction")
max_seq_len = trtllm_args.get("max_seq_len")
enable_attention_dp = trtllm_args.get("enable_attention_dp", False)
tensor_parallel_size = trtllm_args.get("tensor_parallel_size")
self.model_name = model
......@@ -128,6 +136,10 @@ class TRTLLMProcess:
if single_gpu:
# Force all processes to GPU 0 (for single-GPU testing)
gpu_device = "0"
elif enable_attention_dp and tensor_parallel_size:
# For attention DP, TRT-LLM spawns tensor_parallel_size internal MPI workers.
# So one process hosts tensor_parallel_size attention DP ranks and needs visibility into all of those GPUs.
gpu_device = ",".join(str(i) for i in range(tensor_parallel_size))
else:
# Each worker sees one GPU
gpu_device = str(worker_idx)
......@@ -156,6 +168,14 @@ class TRTLLMProcess:
if max_seq_len is not None:
command.extend(["--max-seq-len", str(max_seq_len)])
# Set tensor parallel size if specified (needed for attention DP)
if tensor_parallel_size is not None:
command.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
# Enable attention data parallelism if requested
if enable_attention_dp:
command.append("--enable-attention-dp")
# Each TRT-LLM worker needs a unique DYN_SYSTEM_PORT to avoid conflicts.
# See examples/backends/trtllm/launch/disagg_same_gpu.sh for reference.
system_port = 8081 + worker_idx
......@@ -344,6 +364,70 @@ def test_trtllm_kv_router_basic(
trtllm_workers.__exit__(None, None, None)
@pytest.mark.gpu_2
@pytest.mark.nightly
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
@pytest.mark.timeout(600)  # 10 min max (multi-GPU + DP startup variance)
def test_router_decisions_trtllm_attention_dp(
    request,
    runtime_services_dynamic_ports,
    predownload_models,
    set_ucx_tls_no_mm,
    request_plane,
):
    """Validate KV cache prefix reuse with TRTLLM by sending progressive requests with overlapping prefixes.

    Same flow as test_router_decisions_trtllm_multiple_workers; force first request to (worker_id, dp_rank=1).
    Dump events from router and verify:
        * All but one (worker_id, dp_rank) should have no events (due to prefix reuse)
        * The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
        * All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)

    A single TRT-LLM worker is started with attention DP enabled and
    tensor_parallel_size=2, so the worker exposes two attention DP ranks.
    """
    N_TRTLLM_WORKERS = 1
    N_ATTENTION_DP_RANKS = 2

    # Create trtllm_args with attention DP enabled
    TRTLLM_ADP_ARGS = {
        **TRTLLM_ARGS,
        "enable_attention_dp": True,
        "tensor_parallel_size": N_ATTENTION_DP_RANKS,
    }

    # Sentinel so cleanup can tell whether the worker object was ever created,
    # without inspecting locals().
    trtllm_workers = None
    try:
        logger.info(
            f"Starting 1 TRT-LLM worker with attention DP enabled (attention_dp_size={N_ATTENTION_DP_RANKS})"
        )
        trtllm_workers = TRTLLMProcess(
            request,
            trtllm_args=TRTLLM_ADP_ARGS,
            num_workers=N_TRTLLM_WORKERS,
            single_gpu=False,
            request_plane=request_plane,
        )
        logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
        trtllm_workers.__enter__()

        # Get runtime and create endpoint
        runtime = get_runtime(request_plane=request_plane)
        # Use the namespace from the TRT-LLM workers
        namespace = runtime.namespace(trtllm_workers.namespace)
        component = namespace.component("tensorrt_llm")
        endpoint = component.endpoint("generate")

        _test_router_decisions(
            trtllm_workers,
            endpoint,
            MODEL_NAME,
            request,
            test_dp_rank=True,
            block_size=TRTLLM_BLOCK_SIZE,
        )
    finally:
        # Clean up TRTLLM workers (only if construction got far enough to bind them)
        if trtllm_workers is not None:
            trtllm_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_1
@pytest.mark.parametrize("request_plane", ["tcp"], indirect=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment