Unverified Commit f3987f6c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: prefill routing for trtllm (#3471)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent cf83794a
...@@ -38,17 +38,22 @@ This will start both etcd and NATS with the required configurations in the backg ...@@ -38,17 +38,22 @@ This will start both etcd and NATS with the required configurations in the backg
## Usage Instructions ## Usage Instructions
### Step 1: Launch vLLM Workers ### Step 1: Launch Workers
Make sure you have 8 GPUs for these examples, unless you are using mockers (see below). First, start the vLLM worker engines in a terminal. Make sure you have 8 GPUs for these examples, unless you are using mockers (see below). First, start the worker engines in a terminal.
The script supports three modes:
- **`agg` (default)**: Aggregated/monolithic workers that handle both prefill and decode
- **`decode`**: Workers dedicated to decode (token generation) phase
- **`prefill`**: Workers dedicated to prefill (prompt processing) phase
```bash ```bash
# Default: 8 vLLM workers with DeepSeek model (explicitly sets --block-size 64) # Default: 8 aggregated workers with DeepSeek model (handles both prefill and decode)
./run_engines.sh \ ./run_engines.sh \
--num-workers 8 \ --num-workers 8 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Example: 4 vLLM workers with larger model using tensor parallelism (2 GPUs per worker) # Example: 4 workers with larger model using tensor parallelism (2 GPUs per worker)
# NOTE: this requires having Hopper or later GPU SKUs to support MXFP4 precision. # NOTE: this requires having Hopper or later GPU SKUs to support MXFP4 precision.
./run_engines.sh \ ./run_engines.sh \
--num-workers 4 \ --num-workers 4 \
...@@ -56,19 +61,20 @@ Make sure you have 8 GPUs for these examples, unless you are using mockers (see ...@@ -56,19 +61,20 @@ Make sure you have 8 GPUs for these examples, unless you are using mockers (see
--tensor-parallel-size 2 --tensor-parallel-size 2
``` ```
#### Prefill Workers #### Disaggregated Serving (Decode + Prefill Workers)
You can also launch separate decode and prefill workers for disaggregated serving. This allows you to dedicate specific GPUs to prefill (prompt processing) and decode (token generation) tasks: You can launch separate decode and prefill workers for disaggregated serving. This allows you to dedicate specific GPUs to prefill (prompt processing) and decode (token generation) tasks:
```bash ```bash
# Launch 4 decode workers (GPUs 0-3) # Launch 4 decode workers (GPUs 0-3)
./run_engines.sh \ ./run_engines.sh \
--decode \
--num-workers 4 \ --num-workers 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Launch 4 prefill workers (GPUs 4-7) # Launch 4 prefill workers (GPUs 4-7)
./run_engines.sh \ ./run_engines.sh \
--prefills \ --prefill \
--num-workers 4 \ --num-workers 4 \
--base-gpu-offset 4 \ --base-gpu-offset 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
......
...@@ -8,7 +8,8 @@ NUM_WORKERS=8 ...@@ -8,7 +8,8 @@ NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B" MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1 TENSOR_PARALLEL_SIZE=1
USE_MOCKERS=false USE_MOCKERS=false
USE_PREFILLS=false USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill
BASE_GPU_OFFSET=0 BASE_GPU_OFFSET=0
EXTRA_ARGS=() EXTRA_ARGS=()
...@@ -31,8 +32,16 @@ while [[ $# -gt 0 ]]; do ...@@ -31,8 +32,16 @@ while [[ $# -gt 0 ]]; do
USE_MOCKERS=true USE_MOCKERS=true
shift shift
;; ;;
--prefills) --trtllm)
USE_PREFILLS=true USE_TRTLLM=true
shift
;;
--prefill)
MODE="prefill"
shift
;;
--decode)
MODE="decode"
shift shift
;; ;;
--base-gpu-offset) --base-gpu-offset)
...@@ -45,13 +54,22 @@ while [[ $# -gt 0 ]]; do ...@@ -45,13 +54,22 @@ while [[ $# -gt 0 ]]; do
break break
;; ;;
*) *)
# Collect all other arguments as vLLM/mocker arguments # Collect all other arguments as vLLM/mocker/trtllm arguments
EXTRA_ARGS+=("$1") EXTRA_ARGS+=("$1")
shift shift
;; ;;
esac esac
done done
# Validate that only one engine type is selected
ENGINE_COUNT=0
[ "$USE_MOCKERS" = true ] && ((ENGINE_COUNT++))
[ "$USE_TRTLLM" = true ] && ((ENGINE_COUNT++))
if [ "$ENGINE_COUNT" -gt 1 ]; then
echo "Error: Only one engine type (--mockers, --trtllm, or default vLLM) can be specified"
exit 1
fi
# If no extra args provided, use defaults # If no extra args provided, use defaults
if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then
if [ "$USE_MOCKERS" = true ]; then if [ "$USE_MOCKERS" = true ]; then
...@@ -59,6 +77,21 @@ if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then ...@@ -59,6 +77,21 @@ if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then
EXTRA_ARGS=( EXTRA_ARGS=(
"--block-size" "64" "--block-size" "64"
) )
elif [ "$USE_TRTLLM" = true ]; then
# Default args for TensorRT-LLM engine using predefined YAML configs
# Config files located at: ../../components/backends/trtllm/engine_configs/{agg,decode,prefill}.yaml
if [ "$MODE" = "prefill" ]; then
ENGINE_CONFIG="../../components/backends/trtllm/engine_configs/prefill.yaml"
elif [ "$MODE" = "decode" ]; then
ENGINE_CONFIG="../../components/backends/trtllm/engine_configs/decode.yaml"
else
ENGINE_CONFIG="../../components/backends/trtllm/engine_configs/agg.yaml"
fi
EXTRA_ARGS=(
"--extra-engine-args" "$ENGINE_CONFIG"
"--publish-events-and-metrics"
)
else else
# Default args for vLLM engine (explicitly include block-size) # Default args for vLLM engine (explicitly include block-size)
EXTRA_ARGS=( EXTRA_ARGS=(
...@@ -90,8 +123,15 @@ fi ...@@ -90,8 +123,15 @@ fi
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE)) TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1)) LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:" echo "Configuration:"
echo " Engine Type: $([ "$USE_MOCKERS" = true ] && echo "Mocker" || echo "vLLM")" if [ "$USE_MOCKERS" = true ]; then
echo " Worker Type: $([ "$USE_PREFILLS" = true ] && echo "Prefill" || echo "Decode")" ENGINE_TYPE="Mocker"
elif [ "$USE_TRTLLM" = true ]; then
ENGINE_TYPE="TensorRT-LLM"
else
ENGINE_TYPE="vLLM"
fi
echo " Engine Type: $ENGINE_TYPE"
echo " Mode: $MODE"
echo " Workers: $NUM_WORKERS" echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH" echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE" echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
...@@ -111,12 +151,11 @@ cleanup() { ...@@ -111,12 +151,11 @@ cleanup() {
trap cleanup SIGINT SIGTERM trap cleanup SIGINT SIGTERM
WORKER_TYPE=$([ "$USE_PREFILLS" = true ] && echo "prefill" || echo "decode") echo "Starting $NUM_WORKERS $MODE workers..."
echo "Starting $NUM_WORKERS $WORKER_TYPE workers..."
for i in $(seq 1 $NUM_WORKERS); do for i in $(seq 1 $NUM_WORKERS); do
{ {
echo "[${WORKER_TYPE^} Worker-$i] Starting..." echo "[${MODE^} Worker-$i] Starting..."
# Calculate GPU indices for this worker (with base offset) # Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE )) START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE ))
...@@ -142,13 +181,26 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -142,13 +181,26 @@ for i in $(seq 1 $NUM_WORKERS); do
--model-path "$MODEL_PATH" \ --model-path "$MODEL_PATH" \
--endpoint dyn://test.mocker.generate \ --endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}" "${EXTRA_ARGS[@]}"
elif [ "$USE_TRTLLM" = true ]; then
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
TRTLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$MODE" != "agg" ]; then
TRTLLM_ARGS+=("--disaggregation-mode" "$MODE")
fi
TRTLLM_ARGS+=("${EXTRA_ARGS[@]}")
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}"
else else
echo "[${WORKER_TYPE^} Worker-$i] Using GPUs: $GPU_DEVICES" echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing # Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=() VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH") VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE") VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$USE_PREFILLS" = true ]; then if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker") VLLM_ARGS+=("--is-prefill-worker")
fi fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}") VLLM_ARGS+=("${EXTRA_ARGS[@]}")
...@@ -158,7 +210,7 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -158,7 +210,7 @@ for i in $(seq 1 $NUM_WORKERS); do
fi fi
} & } &
PIDS+=($!) PIDS+=($!)
echo "Started $WORKER_TYPE worker $i (PID: $!)" echo "Started $MODE worker $i (PID: $!)"
done done
echo "All workers started. Press Ctrl+C to stop." echo "All workers started. Press Ctrl+C to stop."
......
...@@ -98,6 +98,7 @@ class StandaloneRouterHandler: ...@@ -98,6 +98,7 @@ class StandaloneRouterHandler:
"output_options": request.get("output_options", {}), "output_options": request.get("output_options", {}),
"eos_token_ids": request.get("eos_token_ids", []), "eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []), "annotations": request.get("annotations", []),
"disaggregated_params": request.get("disaggregated_params"),
"extra_args": request.get("extra_args", {}), "extra_args": request.get("extra_args", {}),
} }
...@@ -116,6 +117,7 @@ class StandaloneRouterHandler: ...@@ -116,6 +117,7 @@ class StandaloneRouterHandler:
"top_logprobs": worker_output.get("top_logprobs"), "top_logprobs": worker_output.get("top_logprobs"),
"finish_reason": worker_output.get("finish_reason"), "finish_reason": worker_output.get("finish_reason"),
"index": worker_output.get("index"), "index": worker_output.get("index"),
"disaggregated_params": worker_output.get("disaggregated_params"),
"extra_args": worker_output.get("extra_args"), "extra_args": worker_output.get("extra_args"),
} }
yield llm_engine_output yield llm_engine_output
......
...@@ -138,6 +138,22 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -138,6 +138,22 @@ async def init(runtime: DistributedRuntime, config: Config):
.client() .client()
) )
# Set up prefill router client for decode workers
next_router_client = None
if config.disaggregation_mode.value == "decode":
try:
logging.info("Initializing prefill router client")
next_router_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("generate")
.client()
)
logging.info("Prefill router client initialized successfully")
except Exception as e:
logging.warning(f"Failed to initialize prefill router client: {e}")
logging.info("Will use direct prefill worker client only")
encode_client = None encode_client = None
if config.encode_endpoint: if config.encode_endpoint:
logging.info( logging.info(
...@@ -329,6 +345,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -329,6 +345,7 @@ async def init(runtime: DistributedRuntime, config: Config):
disaggregation_mode=config.disaggregation_mode, disaggregation_mode=config.disaggregation_mode,
disaggregation_strategy=config.disaggregation_strategy, disaggregation_strategy=config.disaggregation_strategy,
next_client=next_client, next_client=next_client,
next_router_client=next_router_client,
encode_client=encode_client, encode_client=encode_client,
multimodal_processor=multimodal_processor, multimodal_processor=multimodal_processor,
connector=connector, connector=connector,
...@@ -357,13 +374,15 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -357,13 +374,15 @@ async def init(runtime: DistributedRuntime, config: Config):
# Get health check payload (checks env var and falls back to TensorRT-LLM default) # Get health check payload (checks env var and falls back to TensorRT-LLM default)
health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict() health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict()
if config.publish_events_and_metrics and is_first_worker(config): if config.publish_events_and_metrics:
# Initialize and pass in the publisher to the request handler to # Initialize and pass in the publisher to the request handler to
# publish events and metrics. # publish events and metrics.
kv_listener = runtime.namespace(config.namespace).component( kv_listener = runtime.namespace(config.namespace).component(
config.component config.component
) )
metrics_labels = [("model", config.served_model_name)] # Use model_path as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model_path
metrics_labels = [("model", model_name_for_metrics)]
async with get_publisher( async with get_publisher(
component, component,
engine, engine,
......
...@@ -68,6 +68,7 @@ class RequestHandlerConfig: ...@@ -68,6 +68,7 @@ class RequestHandlerConfig:
disaggregation_mode: DisaggregationMode disaggregation_mode: DisaggregationMode
disaggregation_strategy: DisaggregationStrategy disaggregation_strategy: DisaggregationStrategy
next_client: object next_client: object
next_router_client: Optional[object] = None
encode_client: Optional[object] = None encode_client: Optional[object] = None
multimodal_processor: Optional[ multimodal_processor: Optional[
MultimodalRequestProcessor MultimodalRequestProcessor
...@@ -88,6 +89,7 @@ class HandlerBase: ...@@ -88,6 +89,7 @@ class HandlerBase:
self.disaggregation_mode = config.disaggregation_mode self.disaggregation_mode = config.disaggregation_mode
self.disaggregation_strategy = config.disaggregation_strategy self.disaggregation_strategy = config.disaggregation_strategy
self.next_client = config.next_client self.next_client = config.next_client
self.next_router_client = config.next_router_client
self.encode_client = config.encode_client self.encode_client = config.encode_client
self.multimodal_processor = config.multimodal_processor self.multimodal_processor = config.multimodal_processor
self.first_generation = True self.first_generation = True
......
...@@ -191,8 +191,37 @@ class DecodeHandler(HandlerBase): ...@@ -191,8 +191,37 @@ class DecodeHandler(HandlerBase):
super().__init__(config) super().__init__(config)
async def remote_prefill(self, request: dict, context: Context): async def remote_prefill(self, request: dict, context: Context):
async for res in await self.next_client.round_robin(request, context=context): """
yield res Send request to prefill. Try router first if available, fallback to direct worker.
"""
# Format request in PreprocessedRequest format with extra_args
prefill_request = copy.deepcopy(request)
# Try router first if available, fallback to worker
if (
self.next_router_client is not None
and self.next_router_client.instance_ids()
):
try:
# Call router's generate endpoint which returns LLMEngineOutput
async for res in await self.next_router_client.generate(
prefill_request, context=context
):
yield res
return
except Exception as e:
logging.warning(
f"Prefill router call failed: {e}. Falling back to direct worker."
)
# Fallback to direct worker
if self.next_client is not None:
async for res in await self.next_client.round_robin(
prefill_request, context=context
):
yield res
else:
raise ValueError("No prefill router or worker available")
async def generate(self, request: dict, context: Context): async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}") logging.debug(f"New Request ID: {context.id()}")
......
...@@ -271,6 +271,7 @@ fn run_request( ...@@ -271,6 +271,7 @@ fn run_request(
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
}; };
work_request work_request
......
...@@ -302,7 +302,16 @@ pub async fn start_kv_router_background( ...@@ -302,7 +302,16 @@ pub async fn start_kv_router_background(
let key = String::from_utf8_lossy(kv.key()); let key = String::from_utf8_lossy(kv.key());
tracing::info!("Router deleted: {}", key); tracing::info!("Router deleted: {}", key);
// Extract the router UUID from the key (format: kv_routers/<model>/<uuid>) // Only process deletions for routers on the same component
if !key.contains(component.path().as_str()) {
tracing::trace!(
"Skipping router deletion from different component (key: {key}, subscriber component: {})",
component.path()
);
continue;
}
// Extract the router UUID from the key
let Some(router_uuid) = key.split('/').next_back() else { let Some(router_uuid) = key.split('/').next_back() else {
tracing::warn!("Could not extract UUID from router key: {}", key); tracing::warn!("Could not extract UUID from router key: {}", key);
continue; continue;
......
...@@ -210,6 +210,7 @@ mod tests { ...@@ -210,6 +210,7 @@ mod tests {
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
}) })
} }
......
...@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
}; };
......
...@@ -85,6 +85,10 @@ pub struct LLMEngineOutput { ...@@ -85,6 +85,10 @@ pub struct LLMEngineOutput {
// Index field for batch requests to match OpenAI format // Index field for batch requests to match OpenAI format
pub index: Option<u32>, pub index: Option<u32>,
/// Disaggregated execution parameters (for prefill/decode separation)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
/// Additional arguments for extensibility /// Additional arguments for extensibility
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>, pub extra_args: Option<serde_json::Value>,
...@@ -101,6 +105,7 @@ impl LLMEngineOutput { ...@@ -101,6 +105,7 @@ impl LLMEngineOutput {
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled), finish_reason: Some(FinishReason::Cancelled),
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
} }
} }
...@@ -115,6 +120,7 @@ impl LLMEngineOutput { ...@@ -115,6 +120,7 @@ impl LLMEngineOutput {
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
top_logprobs: None, top_logprobs: None,
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
} }
} }
...@@ -129,6 +135,7 @@ impl LLMEngineOutput { ...@@ -129,6 +135,7 @@ impl LLMEngineOutput {
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Length), finish_reason: Some(FinishReason::Length),
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
} }
} }
...@@ -143,6 +150,7 @@ impl LLMEngineOutput { ...@@ -143,6 +150,7 @@ impl LLMEngineOutput {
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)), finish_reason: Some(FinishReason::Error(err_msg)),
index: None, index: None,
disaggregated_params: None,
extra_args: None, extra_args: None,
} }
} }
......
...@@ -60,6 +60,11 @@ pub struct PreprocessedRequest { ...@@ -60,6 +60,11 @@ pub struct PreprocessedRequest {
#[builder(default)] #[builder(default)]
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
/// Disaggregated execution parameters (for prefill/decode separation)
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
/// Additional arguments for extensibility /// Additional arguments for extensibility
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment