Unverified Commit 9fa8125c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: trtllm use unified frontend (#4097)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 427ca9ab
......@@ -17,6 +17,13 @@ This directory contains scripts for benchmarking the Dynamo router with prefix c
- `matplotlib` for plotting results
- `data-generator` package (install with `pip install -e ./benchmarks` from repo root)
> [!Note]
> If running outside a container, set `DYNAMO_HOME` to the root path of your Dynamo repository:
> ```bash
> export DYNAMO_HOME=/path/to/dynamo
> ```
> When running in a container, this defaults to `/workspace`.
### Setting up etcd and NATS
This benchmark requires etcd and NATS. To quickly set them up, run:
......
......@@ -225,7 +225,7 @@ else
if [ "$USE_TRTLLM" = true ]; then
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
# Run TensorRT-LLM engine
TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
TRTLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
......@@ -234,7 +234,7 @@ else
fi
TRTLLM_ARGS+=("${EXTRA_ARGS[@]}")
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python3 -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}"
else
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
......@@ -252,12 +252,18 @@ else
fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}")
exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python -m dynamo.vllm \
exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python3 -m dynamo.vllm \
"${VLLM_ARGS[@]}"
fi
} &
PIDS+=($!)
echo "Started $MODE worker $i (PID: $!)"
# Add delay between TensorRT-LLM worker launches to avoid MPI initialization conflicts
if [ "$USE_TRTLLM" = true ] && [ "$i" -lt "$NUM_WORKERS" ]; then
echo "Waiting 2 seconds before launching next TensorRT-LLM worker..."
sleep 2
fi
done
fi
......
......@@ -119,16 +119,14 @@ class SGLangComponentName:
class TrtllmComponentName:
# Note: Planner only supports DECODE_FIRST strategy in TRT-LLM:
# - Decode worker is the first worker (tensorrt_llm)
# - Prefill worker is the next worker (tensorrt_llm_next)
# Unified frontend architecture (consistent with vLLM/SGLang):
# - Prefill workers use "prefill" component
# - Decode workers use "tensorrt_llm" component
prefill_worker_k8s_name = "TRTLLMPrefillWorker"
prefill_worker_component_name = (
"tensorrt_llm_next" # Prefill is "next" with DECODE_FIRST
)
prefill_worker_component_name = "prefill"
prefill_worker_endpoint = "generate"
decode_worker_k8s_name = "TRTLLMDecodeWorker"
decode_worker_component_name = "tensorrt_llm" # Decode is "first" with DECODE_FIRST
decode_worker_component_name = "tensorrt_llm"
decode_worker_endpoint = "generate"
......
......@@ -45,6 +45,7 @@ from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import get_publisher
from dynamo.trtllm.request_handlers.handler_base import DisaggregationMode
from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig,
RequestHandlerFactory,
......@@ -53,7 +54,6 @@ from dynamo.trtllm.utils.trtllm_utils import (
Config,
cmd_line_args,
deep_update,
is_first_worker,
parse_endpoint,
)
......@@ -126,37 +126,6 @@ async def init(runtime: DistributedRuntime, config: Config):
"""
logging.info(f"Initializing the worker with config: {config}")
next_client = None
if config.next_endpoint:
logging.info(
f"Initializing next worker client for endpoint: {config.next_endpoint}"
)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
config.next_endpoint
)
next_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Set up prefill router client for decode workers
next_router_client = None
if config.disaggregation_mode.value == "decode":
try:
logging.info("Initializing prefill router client")
next_router_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("generate")
.client()
)
logging.info("Prefill router client initialized successfully")
except Exception as e:
logging.warning(f"Failed to initialize prefill router client: {e}")
logging.info("Will use direct prefill worker client only")
encode_client = None
if config.encode_endpoint:
logging.info(
......@@ -273,7 +242,13 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
model_input = ModelInput.Tokens
# Set model type based on disaggregation mode for unified frontend support
if config.disaggregation_mode == DisaggregationMode.PREFILL:
model_type = ModelType.Prefill
else:
model_type = ModelType.Chat | ModelType.Completions
multimodal_processor = None
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
......@@ -376,9 +351,6 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params=default_sampling_params,
publisher=None,
disaggregation_mode=config.disaggregation_mode,
disaggregation_strategy=config.disaggregation_strategy,
next_client=next_client,
next_router_client=next_router_client,
encode_client=encode_client,
multimodal_processor=multimodal_processor,
connector=connector,
......@@ -386,14 +358,10 @@ async def init(runtime: DistributedRuntime, config: Config):
metrics_collector=metrics_collector,
)
if next_client:
logging.info(
f"Waiting for the next endpoint to be ready: {config.next_endpoint}"
)
await next_client.wait_for_instances()
if is_first_worker(config):
# Register the model with runtime config
# Encode workers do NOT register - they're internal workers only
# Prefill and decode workers register - frontend detects their role via ModelType
if config.disaggregation_mode != DisaggregationMode.ENCODE:
await register_llm(
model_input,
model_type,
......
......@@ -52,11 +52,6 @@ class DisaggregationMode(Enum):
ENCODE = "encode"
class DisaggregationStrategy(Enum):
PREFILL_FIRST = "prefill_first"
DECODE_FIRST = "decode_first"
@dataclass
class RequestHandlerConfig:
"""
......@@ -68,9 +63,6 @@ class RequestHandlerConfig:
default_sampling_params: SamplingParams
publisher: Publisher
disaggregation_mode: DisaggregationMode
disaggregation_strategy: DisaggregationStrategy
next_client: object
next_router_client: Optional[object] = None
encode_client: Optional[object] = None
multimodal_processor: Optional[
MultimodalRequestProcessor
......@@ -94,9 +86,6 @@ class HandlerBase:
self.publisher = config.publisher
self.metrics_collector = config.metrics_collector
self.disaggregation_mode = config.disaggregation_mode
self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client
self.next_router_client = config.next_router_client
self.encode_client = config.encode_client
self.multimodal_processor = config.multimodal_processor
self.first_generation = True
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import logging
from dynamo._core import Context
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
HandlerBase,
RequestHandlerConfig,
)
......@@ -26,32 +23,11 @@ class RequestHandlerFactory:
"prefill_and_decode": AggregatedHandler,
}
def _validate_config(self, config: RequestHandlerConfig):
def get_request_handler(self, config: RequestHandlerConfig) -> HandlerBase:
if config.disaggregation_mode.value not in self.handlers:
raise ValueError(
f"Invalid disaggregation_mode '{config.disaggregation_mode.value}'"
)
if not config.next_client:
if (
config.disaggregation_mode == DisaggregationMode.PREFILL
and config.disaggregation_strategy
== DisaggregationStrategy.PREFILL_FIRST
):
raise ValueError(
"Next client is required for the main worker when disaggregation_mode='prefill' and disaggregation_strategy='prefill_first'."
)
if (
config.disaggregation_mode == DisaggregationMode.DECODE
and config.disaggregation_strategy
== DisaggregationStrategy.DECODE_FIRST
):
raise ValueError(
"Next client is required for the decode worker when disaggregation_mode='decode' and disaggregation_strategy='decode_first'."
)
def get_request_handler(self, config: RequestHandlerConfig) -> HandlerBase:
self._validate_config(config)
return self.handlers[config.disaggregation_mode.value](config)
......@@ -104,14 +80,14 @@ class EncodeHandler(HandlerBase):
class PrefillHandler(HandlerBase):
"""
Handler for the prefill mode.
Handler for prefill-only workers in disaggregated serving.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def remote_encode_with_nixl(self, request: dict):
# 2. Get response with shape info and readable metadata
# Get response with shape info and readable metadata
encode_response = None
async for res in await self.encode_client.round_robin(request):
encode_response = res.data()
......@@ -125,12 +101,12 @@ class PrefillHandler(HandlerBase):
encode_response, self.connector
)
async def remote_decode(self, request: dict, context: Context):
async for res in await self.next_client.round_robin(request, context=context):
yield res.data()
async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
"""
Prefill worker: process prompt and return disaggregated_params.
Frontend routes to decode workers automatically.
"""
logging.debug(f"Prefill Request ID: {context.id()}")
logging.debug(f"PrefillHandler.generate received request: {request}")
embeddings_tensor = None
......@@ -138,119 +114,48 @@ class PrefillHandler(HandlerBase):
_, _, embedding_paths = self.multimodal_processor.extract_prompt_and_media(
request.get("messages", [])
)
# This check will be removed once TRTLLM Encoder is integrated.
if embedding_paths:
if self.encode_client and self.connector:
logging.debug(
"PrefillHandler calling Encode Worker via remote_encode_with_nixl"
)
embeddings_tensor = await self.remote_encode_with_nixl(request)
# Generate the prefill response locally
prefill_request = copy.deepcopy(request)
prefill_response = None
# Generate prefill response locally and return disaggregated_params
response_count = 0
async for res in self.generate_locally(
prefill_request, context, embeddings_tensor
):
prefill_response = res
async for res in self.generate_locally(request, context, embeddings_tensor):
response_count += 1
if response_count > 1:
raise ValueError("Prefill response should be generated only once.")
if context.is_stopped() or context.is_killed():
# Local generate abort monitor will print debug log, so only returning here.
return
if (
self.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST
and not self.check_error(prefill_response)
):
# If operating under prefill_first strategy, the prefill handler needs to trigger
# the decode handler.
if prefill_response is not None:
request["disaggregated_params"] = prefill_response[
"disaggregated_params"
]
async for res in self.remote_decode(request, context):
# Return response with disaggregated_params to frontend
yield res
if context.is_stopped() or context.is_killed():
logging.debug(f"Aborted Remote Request ID: {context.id()}")
return
else:
# Return response to the decode handler.
yield prefill_response
class DecodeHandler(HandlerBase):
"""
Handler for the decode mode.
Handler for decode-only workers in disaggregated serving.
"""
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def remote_prefill(self, request: dict, context: Context):
async def generate(self, request: dict, context: Context):
"""
Send request to prefill. Try router first if available, fallback to direct worker.
Decode worker: generate tokens using disaggregated_params from prefill.
If disaggregated_params is present, prefill was done. Otherwise generate normally.
"""
# Format request in PreprocessedRequest format with extra_args
prefill_request = copy.deepcopy(request)
logging.debug(f"Decode Request ID: {context.id()}")
# Try router first if available, fallback to worker
if (
self.next_router_client is not None
and self.next_router_client.instance_ids()
):
try:
# Call router's generate endpoint which returns LLMEngineOutput
async for res in await self.next_router_client.generate(
prefill_request, context=context
):
yield res
return
except Exception as e:
logging.warning(
f"Prefill router call failed: {e}. Falling back to direct worker."
)
# Fallback to direct worker
if self.next_client is not None:
async for res in await self.next_client.round_robin(
prefill_request, context=context
):
yield res
else:
raise ValueError("No prefill router or worker available")
async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
if self.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST:
prefill_response = None
# If operating under decode_first strategy, the decode handler needs to trigger
# the prefill handler.
response_count = 0
# Do not yield the prefill response directly.
# Instead, capture it and extract the state.
async for res in self.remote_prefill(request, context):
prefill_response = res
response_count += 1
if response_count > 1:
raise ValueError("Prefill response should be generated only once.")
if context.is_stopped() or context.is_killed():
logging.debug(f"Aborted Remote Request ID: {context.id()}")
return
response_data = (
prefill_response.data() if prefill_response is not None else None
disaggregated_params = request.get("disaggregated_params")
if disaggregated_params:
logging.debug(
f"Using disaggregated params from prefill for request {context.id()}"
)
if prefill_response is not None and self.check_error(response_data):
yield response_data
return
if prefill_response is not None and response_data is not None:
request["disaggregated_params"] = response_data["disaggregated_params"]
# Generate tokens locally (with or without disaggregated_params)
async for res in self.generate_locally(request, context):
yield res
......@@ -10,19 +10,19 @@ from tensorrt_llm.llmapi import BuildConfig
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.config_dump import add_config_dump_args, register_encoder
from dynamo.trtllm import __version__
from dynamo.trtllm.request_handlers.handler_base import (
DisaggregationMode,
DisaggregationStrategy,
)
from dynamo.trtllm.request_handlers.handler_base import DisaggregationMode
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
# Default endpoint for the next worker.
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.tensorrt_llm.generate"
# Default endpoints for TensorRT-LLM workers
DEFAULT_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.tensorrt_llm.generate" # Decode/aggregated workers
)
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate" # Prefill workers
DEFAULT_ENCODE_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.tensorrt_llm_encode.generate" # Encode workers
)
DEFAULT_MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEFAULT_NEXT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.tensorrt_llm_next.generate"
DEFAULT_ENCODE_ENDPOINT = f"dyn://{DYN_NAMESPACE}.tensorrt_llm_encode.generate"
DEFAULT_DISAGGREGATION_STRATEGY = DisaggregationStrategy.DECODE_FIRST
DEFAULT_DISAGGREGATION_MODE = DisaggregationMode.AGGREGATED
......@@ -50,10 +50,6 @@ class Config:
self.override_engine_args: str = ""
self.publish_events_and_metrics: bool = False
self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
self.disaggregation_strategy: DisaggregationStrategy = (
DEFAULT_DISAGGREGATION_STRATEGY
)
self.next_endpoint: str = ""
self.encode_endpoint: str = ""
self.modality: str = "text"
self.allowed_local_media_path: str = ""
......@@ -85,8 +81,6 @@ class Config:
f"migration_limit={self.migration_limit}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
f"disaggregation_strategy={self.disaggregation_strategy}, "
f"next_endpoint={self.next_endpoint}, "
f"encode_endpoint={self.encode_endpoint}, "
f"modality={self.modality}, "
f"allowed_local_media_path={self.allowed_local_media_path}, "
......@@ -105,24 +99,6 @@ def _preprocess_for_encode_config(
return obj.__dict__
def is_first_worker(config):
"""
Check if the current worker is the first worker in the disaggregation chain.
"""
is_primary_worker = config.disaggregation_mode == DisaggregationMode.AGGREGATED
if not is_primary_worker:
is_primary_worker = (
config.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST
) and (config.disaggregation_mode == DisaggregationMode.PREFILL)
if not is_primary_worker:
is_primary_worker = (
config.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST
) and (config.disaggregation_mode == DisaggregationMode.DECODE)
return is_primary_worker
def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
......@@ -146,7 +122,7 @@ def cmd_line_args():
"--endpoint",
type=str,
default="",
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT} if first worker, {DEFAULT_NEXT_ENDPOINT} if next worker",
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT} for decode/aggregated, {DEFAULT_PREFILL_ENDPOINT} for prefill workers, or {DEFAULT_ENCODE_ENDPOINT} for encode workers",
)
parser.add_argument(
"--model-path",
......@@ -255,13 +231,6 @@ def cmd_line_args():
default=False,
help="Use NIXL Connect for communication between workers.",
)
parser.add_argument(
"--disaggregation-strategy",
type=str,
default=DEFAULT_DISAGGREGATION_STRATEGY,
choices=[strategy.value for strategy in DisaggregationStrategy],
help=f"Strategy to use for disaggregation. Default: {DEFAULT_DISAGGREGATION_STRATEGY}",
)
parser.add_argument(
"--modality",
type=str,
......@@ -269,12 +238,6 @@ def cmd_line_args():
choices=["text", "multimodal"],
help="Modality to use for the model. Default: text. Current supported modalities are image.",
)
parser.add_argument(
"--next-endpoint",
type=str,
default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
)
parser.add_argument(
"--encode-endpoint",
type=str,
......@@ -327,29 +290,18 @@ def cmd_line_args():
# This becomes an `Option` on the Rust side
config.served_model_name = None
# Set the disaggregation mode and strategy.
# Set the disaggregation mode.
config.disaggregation_mode = DisaggregationMode(args.disaggregation_mode)
config.disaggregation_strategy = DisaggregationStrategy(
args.disaggregation_strategy
)
# Set the appropriate defaults for the endpoint and next endpoint.
if is_first_worker(config):
if args.endpoint == "":
args.endpoint = DEFAULT_ENDPOINT
if (
args.next_endpoint == ""
and config.disaggregation_mode != DisaggregationMode.AGGREGATED
):
args.next_endpoint = DEFAULT_NEXT_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.ENCODE:
# Set the appropriate default for the endpoint based on disaggregation mode
if args.endpoint == "":
if config.disaggregation_mode == DisaggregationMode.ENCODE:
args.endpoint = DEFAULT_ENCODE_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
args.endpoint = DEFAULT_PREFILL_ENDPOINT
else:
if args.endpoint == "":
args.endpoint = DEFAULT_NEXT_ENDPOINT
if args.next_endpoint != "":
raise ValueError("Next endpoint is not allowed for the next worker")
# Decode and aggregated workers use "tensorrt_llm" component
args.endpoint = DEFAULT_ENDPOINT
endpoint = args.endpoint
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
......@@ -358,7 +310,6 @@ def cmd_line_args():
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.next_endpoint = args.next_endpoint
config.encode_endpoint = args.encode_endpoint
config.allowed_local_media_path = args.allowed_local_media_path
config.max_file_size_mb = args.max_file_size_mb
......
......@@ -38,7 +38,6 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- [Quick Start](#quick-start)
- [Single Node Examples](#single-node-examples)
- [Advanced Examples](#advanced-examples)
- [Disaggregation Strategy](#disaggregation-strategy)
- [KV Cache Transfer](#kv-cache-transfer-in-disaggregated-serving)
- [Client](#client)
- [Benchmarking](#benchmarking)
......@@ -124,7 +123,7 @@ This figure shows an overview of the major components to deploy:
+------------------+
```
**Note:** The diagram above shows all possible components in a deployment. Depending on the chosen disaggregation strategy, you can configure whether Worker1 handles prefill and Worker2 handles decode, or vice versa. For more information on how to select and configure these strategies, see the [Disaggregation Strategy](#disaggregation-strategy) section below.
**Note:** The diagram above shows all possible components in a deployment. In disaggregated serving, Worker1 acts as the decode worker and Worker2 as the prefill worker, with the unified frontend coordinating request routing between them.
### Aggregated
```bash
......@@ -140,9 +139,6 @@ cd $DYNAMO_HOME/examples/backends/trtllm
### Disaggregated
> [!IMPORTANT]
> Disaggregated serving supports two strategies for request flow: `"prefill_first"` and `"decode_first"`. By default, the script below uses the `"decode_first"` strategy, which can reduce response latency by minimizing extra hops in the return path. You can switch strategies by setting the `DISAGGREGATION_STRATEGY` environment variable.
```bash
cd $DYNAMO_HOME/examples/backends/trtllm
./launch/disagg.sh
......@@ -151,7 +147,7 @@ cd $DYNAMO_HOME/examples/backends/trtllm
### Disaggregated with KV Routing
> [!IMPORTANT]
> Disaggregated serving with KV routing uses a "prefill first" workflow by default. Currently, Dynamo supports KV routing to only one endpoint per model. In disaggregated workflow, it is generally more effective to route requests to the prefill worker. If you wish to use a "decode first" workflow instead, you can simply set the `DISAGGREGATION_STRATEGY` environment variable accordingly.
> In disaggregated workflow, requests are routed to the prefill worker to maximize KV cache reuse.
```bash
cd $DYNAMO_HOME/examples/backends/trtllm
......@@ -199,20 +195,6 @@ NOTE: To send a request to a multi-node deployment, target the node which is run
To benchmark your deployment with AIPerf, see this utility script, configuring the
`model` name and `host` based on your deployment: [perf.sh](../../../benchmarks/llm/perf.sh)
## Disaggregation Strategy
The disaggregation strategy controls how requests are distributed between the prefill and decode workers in a disaggregated deployment.
By default, Dynamo uses a `decode first` strategy: incoming requests are initially routed to the decode worker, which then forwards them to the prefill worker in round-robin fashion. The prefill worker processes the request and returns results to the decode worker for any remaining decode operations.
When using KV routing, however, Dynamo switches to a `prefill first` strategy. In this mode, requests are routed directly to the prefill worker, which can help maximize KV cache reuse and improve overall efficiency for certain workloads. Choosing the appropriate strategy can have a significant impact on performance, depending on your use case.
The disaggregation strategy can be set using the `DISAGGREGATION_STRATEGY` environment variable. You can set the strategy before launching your deployment, for example:
```bash
DISAGGREGATION_STRATEGY="prefill_first" ./launch/disagg.sh
```
## KV Cache Transfer in Disaggregated Serving
Dynamo with TensorRT-LLM supports two methods for transferring KV cache in disaggregated serving: UCX (default) and NIXL (experimental). For detailed information and configuration instructions for each method, see the [KV cache transfer guide](./kv-cache-transfer.md).
......@@ -223,10 +205,14 @@ Dynamo with TensorRT-LLM supports two methods for transferring KV cache in disag
You can enable [request migration](../../../docs/fault_tolerance/request_migration.md) to handle worker failures gracefully. Use the `--migration-limit` flag to specify how many times a request can be migrated to another worker:
```bash
# For decode and aggregated workers
python3 -m dynamo.trtllm ... --migration-limit=3
```
This allows a request to be migrated up to 3 times before failing. See the [Request Migration Architecture](../../../docs/fault_tolerance/request_migration.md) documentation for details on how this works.
> [!IMPORTANT]
> **Prefill workers do not support request migration** and must use `--migration-limit=0` (the default). Prefill workers only process prompts and return KV cache state - they don't maintain long-running generation requests that would benefit from migration.
See the [Request Migration Architecture](../../../docs/fault_tolerance/request_migration.md) documentation for details on how this works.
## Request Cancellation
......@@ -237,8 +223,7 @@ When a user cancels a request (e.g., by disconnecting from the frontend), the re
| | Prefill | Decode |
|-|---------|--------|
| **Aggregated** | ✅ | ✅ |
| **Disaggregated (Decode-First)** | ✅ | ✅ |
| **Disaggregated (Prefill-First)** | ✅ | ✅ |
| **Disaggregated** | ✅ | ✅ |
For more details, see the [Request Cancellation Architecture](../../fault_tolerance/request_cancellation.md) documentation.
......
......@@ -149,7 +149,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m dynamo.trtllm \
--dyn-reasoning-parser gpt_oss \
--dyn-tool-call-parser harmony \
--disaggregation-mode prefill \
--disaggregation-strategy prefill_first \
--max-num-tokens 20000 \
--max-batch-size 32 \
--free-gpu-memory-fraction 0.9 \
......@@ -166,7 +165,6 @@ CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m dynamo.trtllm \
--dyn-reasoning-parser gpt_oss \
--dyn-tool-call-parser harmony \
--disaggregation-mode decode \
--disaggregation-strategy prefill_first \
--max-num-tokens 16384 \
--free-gpu-memory-fraction 0.9 \
--tensor-parallel-size 4 \
......@@ -185,7 +183,7 @@ Make sure that both of the endpoints are available before sending an inference r
{
"endpoints": [
"dyn://dynamo.tensorrt_llm.generate",
"dyn://dynamo.tensorrt_llm_next.generate"
"dyn://dynamo.prefill.generate"
],
"status": "healthy"
}
......
......@@ -27,8 +27,6 @@ This guide demonstrates how to deploy Llama 4 Maverick Instruct with Eagle Specu
- One node runs the decode worker.
- The other node runs the prefill worker.
For advanced control over how requests are routed between prefill and decode workers in disaggregated mode, refer to the [Disaggregation Strategy](./README.md#disaggregation-strategy) section.
## Notes
* Make sure the (`eagle3_one_model: true`) is set in the LLM API config inside the `recipes/llama4/trtllm/eagle` folder.
......
......@@ -57,23 +57,21 @@ The EPD flow implements a **3-worker architecture** for high-performance multimo
- **Prefill Worker**: Handles initial context processing and KV-cache generation
- **Decode Worker**: Performs streaming token generation
## Request Flow Diagrams
### Prefill-First Disaggregation Strategy
## Request Flow Diagram
```mermaid
sequenceDiagram
participant Client
participant Gateway
participant PrefillWorker as "Prefill Worker<br/>(AggregatedHandler)"
participant Frontend
participant PrefillWorker as "Prefill Worker<br/>(PrefillHandler)"
participant EncodeWorker as "Encode Worker<br/>(EncodeHandler)"
participant DecodeWorker as "Decode Worker<br/>(DecodeHandler)"
participant NIXL as "NIXL<br/>(RDMA Transfer)"
Note over Client,NIXL: Prefill-First Strategy: Context processing first, then streaming generation
Note over Client,NIXL: Unified Frontend: Context processing followed by streaming generation
Client->>Gateway: POST /v1/chat/completions<br/>(multimodal request)
Gateway->>PrefillWorker: Route request
Client->>Frontend: POST /v1/chat/completions<br/>(multimodal request)
Frontend->>PrefillWorker: Route to prefill worker
Note over PrefillWorker: Check for multimodal content
PrefillWorker->>EncodeWorker: Send request<br/>(contains embedding paths)
......@@ -90,74 +88,24 @@ sequenceDiagram
Note over PrefillWorker: Process full context<br/>(text + multimodal embeddings)
Note over PrefillWorker: Generate KV-cache<br/>(max_tokens=1 in prefill mode)
PrefillWorker->>DecodeWorker: Transfer KV-cache + disaggregated_params<br/>(generation_only mode)
Note over DecodeWorker: Continue generation<br/>(streaming tokens)
DecodeWorker->>Gateway: Stream response chunk 1
Gateway->>Client: Response chunk 1
DecodeWorker->>Gateway: Stream response chunk 2
Gateway->>Client: Response chunk 2
DecodeWorker->>Gateway: ... (continue streaming)
Gateway->>Client: ... (continue streaming)
DecodeWorker->>Gateway: Final response + [DONE]
Gateway->>Client: Final response + [DONE]
```
### Decode-First Disaggregation Strategy
```mermaid
sequenceDiagram
participant Client
participant Gateway
participant DecodeWorker as "Decode Worker<br/>(DecodeHandler)<br/>PRIMARY"
participant PrefillWorker as "Prefill Worker<br/>(PrefillHandler)"
participant EncodeWorker as "Encode Worker<br/>(EncodeHandler)"
participant NIXL as "NIXL<br/>(RDMA Transfer)"
Note over Client,NIXL: Decode-First Strategy: DecodeWorker orchestrates prefill then handles generation
Client->>Gateway: POST /v1/chat/completions<br/>(multimodal request)
Gateway->>DecodeWorker: Route request<br/>(primary worker)
PrefillWorker->>Frontend: Return prefill response<br/>(disaggregated_params)
Note over DecodeWorker: Check disaggregation_strategy == DECODE_FIRST
Note over DecodeWorker: Call remote_prefill() to trigger prefill
Frontend->>DecodeWorker: Route to decode worker<br/>with disaggregated_params
DecodeWorker->>PrefillWorker: Send request via remote_prefill()
Note over PrefillWorker: Check for multimodal content
PrefillWorker->>EncodeWorker: Send request<br/>(contains embedding paths)
Note over EncodeWorker: Load embeddings from file<br/>
EncodeWorker->>NIXL: Create readable operation<br/>
EncodeWorker->>PrefillWorker: Send metadata + NIXL info<br/>(JSON: shape, dtype, aux_data)
Note over PrefillWorker: Allocate tensor with dynamic shape
PrefillWorker->>NIXL: Begin read operation
NIXL-->>PrefillWorker: Zero-copy transfer complete<br/>
Note over PrefillWorker: Reconstruct embeddings<br/>(mm_embeddings + special_tokens + offsets)
Note over PrefillWorker: Process full context<br/>(text + multimodal embeddings)
Note over PrefillWorker: Generate prefill response<br/>(max_tokens=1 in prefill mode)
PrefillWorker->>DecodeWorker: Return prefill response<br/>(disaggregated_params)
Note over DecodeWorker: Extract disaggregated_params<br/>from prefill_response
Note over DecodeWorker: Update request with params<br/>request["disaggregated_params"] = response_data["disaggregated_params"]
Note over DecodeWorker: Begin local generation<br/>(generate_locally with prefill state)
DecodeWorker->>Gateway: Stream response chunk 1
Gateway->>Client: Response chunk 1
DecodeWorker->>Gateway: Stream response chunk 2
Gateway->>Client: Response chunk 2
DecodeWorker->>Gateway: ... (continue streaming)
Gateway->>Client: ... (continue streaming)
DecodeWorker->>Gateway: Final response + [DONE]
Gateway->>Client: Final response + [DONE]
Note over DecodeWorker: Continue generation<br/>(streaming tokens)
DecodeWorker->>Frontend: Stream response chunk 1
Frontend->>Client: Response chunk 1
DecodeWorker->>Frontend: Stream response chunk 2
Frontend->>Client: Response chunk 2
DecodeWorker->>Frontend: ... (continue streaming)
Frontend->>Client: ... (continue streaming)
DecodeWorker->>Frontend: Final response + [DONE]
Frontend->>Client: Final response + [DONE]
```
## How the System Works
1. **Request Processing**: Multimodal requests containing embedding file paths OR urls are routed based on disaggregation strategy
1. **Request Processing**: Multimodal requests containing embedding file paths or URLs are routed by the frontend to prefill workers
2. **Multimodal Loading**: EncodeWorker loads large embedding files and extracts auxiliary metadata
3. **NIXL Transfer**: Main tensors transferred via zero-copy RDMA, small metadata via JSON for efficiency
4. **Dynamic Allocation**: Consumer workers allocate tensors with exact shapes received from EncodeWorker
......
......@@ -79,7 +79,6 @@ cd $DYNAMO_HOME
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-VL-7B-Instruct"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen2-VL-7B-Instruct"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"recipes/qwen2-vl-7b-instruct/trtllm/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"recipes/qwen2-vl-7b-instruct/trtllm/decode.yaml"}
export MODALITY=${MODALITY:-"multimodal"}
......
......@@ -212,24 +212,9 @@ spec:
TensorRT-LLM workers are configured through command-line arguments in the deployment YAML. Key configuration areas include:
- **Disaggregation Strategy**: Control request flow with `DISAGGREGATION_STRATEGY` environment variable
- **KV Cache Transfer**: Choose between UCX (default) or NIXL for disaggregated serving
- **Request Migration**: Enable graceful failure handling with `--migration-limit`
### Disaggregation Strategy
The disaggregation strategy controls how requests are distributed between prefill and decode workers:
- **`decode_first`** (default): Requests routed to decode worker first, then forwarded to prefill worker
- **`prefill_first`**: Requests routed directly to prefill worker (used with KV routing)
Set via environment variable:
```yaml
envs:
- name: DISAGGREGATION_STRATEGY
value: "prefill_first"
```
## Testing the Deployment
Send a test request to verify your deployment. See the [client section](../../../../docs/backends/vllm/README.md#client) for detailed instructions.
......
......@@ -142,8 +142,6 @@ spec:
- /workspace/prefill.yaml
- --disaggregation-mode
- prefill
- --disaggregation-strategy
- decode_first
decode:
volumeMounts:
- name: models
......@@ -182,5 +180,3 @@ spec:
- /workspace/decode.yaml
- --disaggregation-mode
- decode
- --disaggregation-strategy
- decode_first
......@@ -40,8 +40,6 @@ spec:
- ./recipes/qwen3/trtllm/prefill.yaml
- --disaggregation-mode
- prefill
- --disaggregation-strategy
- decode_first
TRTLLMDecodeWorker:
dynamoNamespace: trtllm-disagg
envFromSecret: hf-token-secret
......@@ -68,5 +66,3 @@ spec:
- ./recipes/qwen3/trtllm/decode.yaml
- --disaggregation-mode
- decode
- --disaggregation-strategy
- decode_first
......@@ -100,8 +100,6 @@ spec:
- ./recipes/qwen3/trtllm/decode.yaml
- --disaggregation-mode
- decode
- --disaggregation-strategy
- decode_first
TRTLLMPrefillWorker:
dynamoNamespace: trtllm-disagg-planner
envFromSecret: hf-token-secret
......@@ -129,5 +127,3 @@ spec:
- ./recipes/qwen3/trtllm/prefill.yaml
- --disaggregation-mode
- prefill
- --disaggregation-strategy
- decode_first
......@@ -42,14 +42,12 @@ spec:
- ./recipes/qwen3/trtllm/prefill.yaml
- --disaggregation-mode
- prefill
- --disaggregation-strategy
- prefill_first
- --publish-events-and-metrics
TRTLLMDecodeWorker:
dynamoNamespace: trtllm-v1-disagg-router
envFromSecret: hf-token-secret
componentType: worker
replicas: 1
replicas: 2
resources:
limits:
gpu: "1"
......@@ -70,5 +68,3 @@ spec:
- ./recipes/qwen3/trtllm/decode.yaml
- --disaggregation-mode
- decode
- --disaggregation-strategy
- prefill_first
......@@ -6,7 +6,6 @@
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-0.6B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen3-0.6B"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/recipes/qwen3/trtllm/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/recipes/qwen3/trtllm/decode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
......@@ -34,7 +33,6 @@ CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--disaggregation-mode prefill &
PREFILL_PID=$!
......@@ -44,6 +42,5 @@ CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--disaggregation-mode decode
......@@ -6,7 +6,6 @@
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-0.6B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen3-0.6B"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"prefill_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/recipes/qwen3/trtllm/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/recipes/qwen3/trtllm/decode.yaml"}
export PREFILL_CUDA_VISIBLE_DEVICES=${PREFILL_CUDA_VISIBLE_DEVICES:-"0"}
......@@ -22,34 +21,25 @@ cleanup() {
trap cleanup EXIT INT TERM
# run frontend
# run frontend with KV routing for cache-aware optimization
python3 -m dynamo.frontend --router-mode kv --http-port 8000 &
DYNAMO_PID=$!
EXTRA_PREFILL_ARGS=()
EXTRA_DECODE_ARGS=()
if [ "$DISAGGREGATION_STRATEGY" == "prefill_first" ]; then
EXTRA_PREFILL_ARGS+=(--publish-events-and-metrics)
else
EXTRA_DECODE_ARGS+=(--publish-events-and-metrics)
fi
# run prefill worker
# Publishes KV events for router's cache-aware routing
# No next_endpoint needed - unified frontend handles routing
CUDA_VISIBLE_DEVICES=$PREFILL_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-mode prefill \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
"${EXTRA_PREFILL_ARGS[@]}" &
--publish-events-and-metrics &
PREFILL_PID=$!
# run decode worker
# No event publishing needed - prefill handles it
CUDA_VISIBLE_DEVICES=$DECODE_CUDA_VISIBLE_DEVICES python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-mode decode \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
"${EXTRA_DECODE_ARGS[@]}"
--disaggregation-mode decode
......@@ -32,7 +32,6 @@ echo "GPU memory check passed: ${FREE_GPU_GB}GB available (required: ${REQUIRED_
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-0.6B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"Qwen/Qwen3-0.6B"}
export DISAGGREGATION_STRATEGY=${DISAGGREGATION_STRATEGY:-"decode_first"}
export PREFILL_ENGINE_ARGS=${PREFILL_ENGINE_ARGS:-"$DYNAMO_HOME/tests/serve/configs/trtllm/prefill.yaml"}
export DECODE_ENGINE_ARGS=${DECODE_ENGINE_ARGS:-"$DYNAMO_HOME/tests/serve/configs/trtllm/decode.yaml"}
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"}
......@@ -59,7 +58,6 @@ python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$PREFILL_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--publish-events-and-metrics \
--disaggregation-mode prefill &
......@@ -72,7 +70,6 @@ python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--extra-engine-args "$DECODE_ENGINE_ARGS" \
--disaggregation-strategy "$DISAGGREGATION_STRATEGY" \
--modality "$MODALITY" \
--publish-events-and-metrics \
--disaggregation-mode decode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment