Unverified Commit f3987f6c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: prefill routing for trtllm (#3471)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent cf83794a
......@@ -38,17 +38,22 @@ This will start both etcd and NATS with the required configurations in the backg
## Usage Instructions
### Step 1: Launch vLLM Workers
### Step 1: Launch Workers
Make sure you have 8 GPUs for these examples, unless you are using mockers (see below). First, start the vLLM worker engines in a terminal.
Make sure you have 8 GPUs for these examples, unless you are using mockers (see below). First, start the worker engines in a terminal.
The script supports three modes:
- **`agg` (default)**: Aggregated/monolithic workers that handle both prefill and decode
- **`decode`**: Workers dedicated to decode (token generation) phase
- **`prefill`**: Workers dedicated to prefill (prompt processing) phase
```bash
# Default: 8 vLLM workers with DeepSeek model (explicitly sets --block-size 64)
# Default: 8 aggregated workers with DeepSeek model (handles both prefill and decode)
./run_engines.sh \
--num-workers 8 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Example: 4 vLLM workers with larger model using tensor parallelism (2 GPUs per worker)
# Example: 4 workers with larger model using tensor parallelism (2 GPUs per worker)
# NOTE: this requires having Hopper or later GPU SKUs to support MXFP4 precision.
./run_engines.sh \
--num-workers 4 \
......@@ -56,19 +61,20 @@ Make sure you have 8 GPUs for these examples, unless you are using mockers (see
--tensor-parallel-size 2
```
#### Prefill Workers
#### Disaggregated Serving (Decode + Prefill Workers)
You can also launch separate decode and prefill workers for disaggregated serving. This allows you to dedicate specific GPUs to prefill (prompt processing) and decode (token generation) tasks:
You can launch separate decode and prefill workers for disaggregated serving. This allows you to dedicate specific GPUs to prefill (prompt processing) and decode (token generation) tasks:
```bash
# Launch 4 decode workers (GPUs 0-3)
./run_engines.sh \
--decode \
--num-workers 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Launch 4 prefill workers (GPUs 4-7)
./run_engines.sh \
--prefills \
--prefill \
--num-workers 4 \
--base-gpu-offset 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
......
......@@ -8,7 +8,8 @@ NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1
USE_MOCKERS=false
USE_PREFILLS=false
USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill
BASE_GPU_OFFSET=0
EXTRA_ARGS=()
......@@ -31,8 +32,16 @@ while [[ $# -gt 0 ]]; do
USE_MOCKERS=true
shift
;;
--prefills)
USE_PREFILLS=true
--trtllm)
USE_TRTLLM=true
shift
;;
--prefill)
MODE="prefill"
shift
;;
--decode)
MODE="decode"
shift
;;
--base-gpu-offset)
......@@ -45,13 +54,22 @@ while [[ $# -gt 0 ]]; do
break
;;
*)
# Collect all other arguments as vLLM/mocker arguments
# Collect all other arguments as vLLM/mocker/trtllm arguments
EXTRA_ARGS+=("$1")
shift
;;
esac
done
# Validate that only one engine type is selected
ENGINE_COUNT=0
[ "$USE_MOCKERS" = true ] && ((ENGINE_COUNT++))
[ "$USE_TRTLLM" = true ] && ((ENGINE_COUNT++))
if [ "$ENGINE_COUNT" -gt 1 ]; then
echo "Error: Only one engine type (--mockers, --trtllm, or default vLLM) can be specified"
exit 1
fi
# If no extra args provided, use defaults
if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then
if [ "$USE_MOCKERS" = true ]; then
......@@ -59,6 +77,21 @@ if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then
EXTRA_ARGS=(
"--block-size" "64"
)
elif [ "$USE_TRTLLM" = true ]; then
# Default args for TensorRT-LLM engine using predefined YAML configs
# Config files located at: ../../components/backends/trtllm/engine_configs/{agg,decode,prefill}.yaml
if [ "$MODE" = "prefill" ]; then
ENGINE_CONFIG="../../components/backends/trtllm/engine_configs/prefill.yaml"
elif [ "$MODE" = "decode" ]; then
ENGINE_CONFIG="../../components/backends/trtllm/engine_configs/decode.yaml"
else
ENGINE_CONFIG="../../components/backends/trtllm/engine_configs/agg.yaml"
fi
EXTRA_ARGS=(
"--extra-engine-args" "$ENGINE_CONFIG"
"--publish-events-and-metrics"
)
else
# Default args for vLLM engine (explicitly include block-size)
EXTRA_ARGS=(
......@@ -90,8 +123,15 @@ fi
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:"
echo " Engine Type: $([ "$USE_MOCKERS" = true ] && echo "Mocker" || echo "vLLM")"
echo " Worker Type: $([ "$USE_PREFILLS" = true ] && echo "Prefill" || echo "Decode")"
if [ "$USE_MOCKERS" = true ]; then
ENGINE_TYPE="Mocker"
elif [ "$USE_TRTLLM" = true ]; then
ENGINE_TYPE="TensorRT-LLM"
else
ENGINE_TYPE="vLLM"
fi
echo " Engine Type: $ENGINE_TYPE"
echo " Mode: $MODE"
echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
......@@ -111,12 +151,11 @@ cleanup() {
trap cleanup SIGINT SIGTERM
WORKER_TYPE=$([ "$USE_PREFILLS" = true ] && echo "prefill" || echo "decode")
echo "Starting $NUM_WORKERS $WORKER_TYPE workers..."
echo "Starting $NUM_WORKERS $MODE workers..."
for i in $(seq 1 $NUM_WORKERS); do
{
echo "[${WORKER_TYPE^} Worker-$i] Starting..."
echo "[${MODE^} Worker-$i] Starting..."
# Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE ))
......@@ -142,13 +181,26 @@ for i in $(seq 1 $NUM_WORKERS); do
--model-path "$MODEL_PATH" \
--endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}"
elif [ "$USE_TRTLLM" = true ]; then
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
TRTLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$MODE" != "agg" ]; then
TRTLLM_ARGS+=("--disaggregation-mode" "$MODE")
fi
TRTLLM_ARGS+=("${EXTRA_ARGS[@]}")
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}"
else
echo "[${WORKER_TYPE^} Worker-$i] Using GPUs: $GPU_DEVICES"
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$USE_PREFILLS" = true ]; then
if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker")
fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}")
......@@ -158,7 +210,7 @@ for i in $(seq 1 $NUM_WORKERS); do
fi
} &
PIDS+=($!)
echo "Started $WORKER_TYPE worker $i (PID: $!)"
echo "Started $MODE worker $i (PID: $!)"
done
echo "All workers started. Press Ctrl+C to stop."
......
......@@ -98,6 +98,7 @@ class StandaloneRouterHandler:
"output_options": request.get("output_options", {}),
"eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []),
"disaggregated_params": request.get("disaggregated_params"),
"extra_args": request.get("extra_args", {}),
}
......@@ -116,6 +117,7 @@ class StandaloneRouterHandler:
"top_logprobs": worker_output.get("top_logprobs"),
"finish_reason": worker_output.get("finish_reason"),
"index": worker_output.get("index"),
"disaggregated_params": worker_output.get("disaggregated_params"),
"extra_args": worker_output.get("extra_args"),
}
yield llm_engine_output
......
......@@ -138,6 +138,22 @@ async def init(runtime: DistributedRuntime, config: Config):
.client()
)
# Set up prefill router client for decode workers
next_router_client = None
if config.disaggregation_mode.value == "decode":
try:
logging.info("Initializing prefill router client")
next_router_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("generate")
.client()
)
logging.info("Prefill router client initialized successfully")
except Exception as e:
logging.warning(f"Failed to initialize prefill router client: {e}")
logging.info("Will use direct prefill worker client only")
encode_client = None
if config.encode_endpoint:
logging.info(
......@@ -329,6 +345,7 @@ async def init(runtime: DistributedRuntime, config: Config):
disaggregation_mode=config.disaggregation_mode,
disaggregation_strategy=config.disaggregation_strategy,
next_client=next_client,
next_router_client=next_router_client,
encode_client=encode_client,
multimodal_processor=multimodal_processor,
connector=connector,
......@@ -357,13 +374,15 @@ async def init(runtime: DistributedRuntime, config: Config):
# Get health check payload (checks env var and falls back to TensorRT-LLM default)
health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict()
if config.publish_events_and_metrics and is_first_worker(config):
if config.publish_events_and_metrics:
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component(
config.component
)
metrics_labels = [("model", config.served_model_name)]
# Use model_path as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model_path
metrics_labels = [("model", model_name_for_metrics)]
async with get_publisher(
component,
engine,
......
......@@ -68,6 +68,7 @@ class RequestHandlerConfig:
disaggregation_mode: DisaggregationMode
disaggregation_strategy: DisaggregationStrategy
next_client: object
next_router_client: Optional[object] = None
encode_client: Optional[object] = None
multimodal_processor: Optional[
MultimodalRequestProcessor
......@@ -88,6 +89,7 @@ class HandlerBase:
self.disaggregation_mode = config.disaggregation_mode
self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client
self.next_router_client = config.next_router_client
self.encode_client = config.encode_client
self.multimodal_processor = config.multimodal_processor
self.first_generation = True
......
......@@ -191,8 +191,37 @@ class DecodeHandler(HandlerBase):
super().__init__(config)
async def remote_prefill(self, request: dict, context: Context):
async for res in await self.next_client.round_robin(request, context=context):
"""
Send request to prefill. Try router first if available, fallback to direct worker.
"""
# Format request in PreprocessedRequest format with extra_args
prefill_request = copy.deepcopy(request)
# Try router first if available, fallback to worker
if (
self.next_router_client is not None
and self.next_router_client.instance_ids()
):
try:
# Call router's generate endpoint which returns LLMEngineOutput
async for res in await self.next_router_client.generate(
prefill_request, context=context
):
yield res
return
except Exception as e:
logging.warning(
f"Prefill router call failed: {e}. Falling back to direct worker."
)
# Fallback to direct worker
if self.next_client is not None:
async for res in await self.next_client.round_robin(
prefill_request, context=context
):
yield res
else:
raise ValueError("No prefill router or worker available")
async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
......
......@@ -271,6 +271,7 @@ fn run_request(
top_logprobs: None,
finish_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
};
work_request
......
......@@ -302,7 +302,16 @@ pub async fn start_kv_router_background(
let key = String::from_utf8_lossy(kv.key());
tracing::info!("Router deleted: {}", key);
// Extract the router UUID from the key (format: kv_routers/<model>/<uuid>)
// Only process deletions for routers on the same component
if !key.contains(component.path().as_str()) {
tracing::trace!(
"Skipping router deletion from different component (key: {key}, subscriber component: {})",
component.path()
);
continue;
}
// Extract the router UUID from the key
let Some(router_uuid) = key.split('/').next_back() else {
tracing::warn!("Could not extract UUID from router key: {}", key);
continue;
......
......@@ -210,6 +210,7 @@ mod tests {
top_logprobs: None,
finish_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
})
}
......
......@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None,
finish_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
};
......
......@@ -85,6 +85,10 @@ pub struct LLMEngineOutput {
// Index field for batch requests to match OpenAI format
pub index: Option<u32>,
/// Disaggregated execution parameters (for prefill/decode separation)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
/// Additional arguments for extensibility
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
......@@ -101,6 +105,7 @@ impl LLMEngineOutput {
top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled),
index: None,
disaggregated_params: None,
extra_args: None,
}
}
......@@ -115,6 +120,7 @@ impl LLMEngineOutput {
finish_reason: Some(FinishReason::Stop),
top_logprobs: None,
index: None,
disaggregated_params: None,
extra_args: None,
}
}
......@@ -129,6 +135,7 @@ impl LLMEngineOutput {
top_logprobs: None,
finish_reason: Some(FinishReason::Length),
index: None,
disaggregated_params: None,
extra_args: None,
}
}
......@@ -143,6 +150,7 @@ impl LLMEngineOutput {
top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)),
index: None,
disaggregated_params: None,
extra_args: None,
}
}
......
......@@ -60,6 +60,11 @@ pub struct PreprocessedRequest {
#[builder(default)]
pub router_config_override: Option<RouterConfigOverride>,
/// Disaggregated execution parameters (for prefill/decode separation)
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
/// Additional arguments for extensibility
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment