Unverified Commit 94f100bc authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Router benchmarking (#2828)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 1995ef9a
<!-- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Router Benchmarking Guide
This directory contains scripts for benchmarking the Dynamo router with prefix caching. The benchmarks measure performance improvements from prefix sharing across requests.
## Prerequisites
- NVIDIA GPUs (8 GPUs for default configuration)
- CUDA environment properly configured
- etcd and NATS running (required for Dynamo coordination)
- Required Python packages:
- `dynamo` package (with vllm and frontend modules)
- `genai-perf` for benchmarking
- `matplotlib` for plotting results
### Setting up etcd and NATS
This benchmark requires etcd and NATS. To quickly set them up, run:
```bash
# From the repository root:
docker compose -f deploy/docker-compose.yml up -d
```
This will start both etcd and NATS with the required configurations in the background.
## Scripts Overview
- **`run_engines.sh`** - Launches multiple vLLM worker instances
- **`ping.sh`** - Simple test script to verify the setup is working
- **`prefix_ratio_benchmark.py`** - Main benchmarking script that sweeps prefix ratios
- **`plot_prefix_ratio_comparison.py`** - Generates comparison plots from benchmark results
## Usage Instructions
### Step 1: Launch vLLM Workers
First, start the vLLM worker engines in a terminal.
```bash
# Default: 8 vLLM workers with DeepSeek model (explicitly sets --block-size 64)
./run_engines.sh \
--num-workers 8 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Example: 4 vLLM workers with larger model using tensor parallelism (2 GPUs per worker)
./run_engines.sh \
--num-workers 4 \
--model-path openai/gpt-oss-120b \
--tensor-parallel-size 2
```
#### Alternative: Launch vLLM Mock Workers
We also supports running lightweight mock engines that simulate vLLM behavior without performing actual model inference. Mocker engines are useful for testing router logic and performance without GPU requirements. Use the `--mockers` flag to run mocker engines instead of real vLLM workers.
```bash
# Example: Running mocker engines for testing (no GPU required)
./run_engines.sh --mockers \
--num-workers 8 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--block-size 64 \
--speedup-ratio 2.0
```
**Note**: The `--speedup-ratio` parameter controls the inference speed of mocker engines. A higher value (e.g., 2.0) makes the mocker engines simulate faster inference, allowing benchmarks to complete more quickly. This is particularly useful for testing router performance without waiting for realistic inference times.
### Step 2: Start the Router
In a **new terminal**, launch the Dynamo router using the Python CLI:
```bash
python -m dynamo.frontend \
--router-mode kv \
--kv-cache-block-size 64 \
--router-reset-states \
--http-port 8000
```
This starts the router with:
- KV cache routing mode
- Block size of 64 (**Important:** This should match the `--block-size` used by your engines)
- `--router-reset-states` flag to clear the event cache (JetStream) from previous runs (useful for single router benchmarking)
- HTTP port 8000
To see all available router arguments, run:
```bash
python -m dynamo.frontend --help
```
For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
**Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead:
```bash
python -m dynamo.frontend \
--router-mode kv \
--kv-cache-block-size 64 \
--http-port 8000 \
--no-kv-events
```
### Step 3: Verify Setup
In another terminal, test that everything is working:
```bash
./ping.sh
# Or specify a different port:
./ping.sh 8000
```
This sends a simple test request to the router. You should see a streamed response if everything is configured correctly.
### Step 4: Run Benchmarks
Once the setup is verified, run the prefix ratio benchmark:
```bash
python prefix_ratio_benchmark.py
```
Default configuration:
- Tests prefix ratios: 0.5 (can be customized with `--prefix-ratios 0.1 0.3 0.5 0.7 0.9`)
- Input sequence length: 14000 tokens
- Output sequence length: 200 tokens
- Requests: 200
- Concurrency: 20
You can customize the benchmark:
```bash
# Test multiple prefix ratios
python prefix_ratio_benchmark.py --prefix-ratios 0.1 0.3 0.5 0.7 0.9
# Adjust input/output lengths
python prefix_ratio_benchmark.py --isl 10000 --osl 500
# Change request count and concurrency
python prefix_ratio_benchmark.py --requests 500 --concurrency 50
# Use multiple router endpoints for parallel benchmarking (for testing multiple Router replicas)
python prefix_ratio_benchmark.py --url http://localhost:8000 http://localhost:8001
# Specify output directory
python prefix_ratio_benchmark.py --output-dir results/experiment1
```
### Benchmark Output
The benchmark script generates:
1. **Performance plots** (`prefix_ratio_performance.png`):
- TTFT (Time to First Token) vs Prefix Ratio
- Throughput (tokens/s) vs Prefix Ratio
2. **Results summary** (`results_summary.json`):
- Raw data for all prefix ratios tested
- Configuration parameters used
3. **Detailed artifacts** (in subdirectories):
- Full genai-perf profiling data for each run
## Troubleshooting
1. **Workers fail to start**: Check CUDA_VISIBLE_DEVICES and GPU availability
2. **Router connection refused**: Ensure router is running and port is correct
3. **Benchmark timeout**: Decrease concurrency or reduce request count
4. **OOM errors**: Reduce max-num-batched-tokens or max-model-len in run_engines.sh
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Get port from first argument, default to 8080 if not provided
PORT=${1:-8080}
curl -X POST http://localhost:${PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{"role": "system", "content": "You are a helpful assistant. Answer in 5 words."},
{"role": "user", "content": "What is 2+2?"}
],
"stream": true,
"max_tokens": 10,
"ignore_eos": true
}'
\ No newline at end of file
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import logging
import os
import subprocess
from typing import Dict, List, Optional
import matplotlib
matplotlib.use("Agg") # Use non-interactive backend
import matplotlib.pyplot as plt
# Setup logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
def get_genai_perf_cmd(
model,
tokenizer, # Add tokenizer parameter
prefix_ratio,
isl,
osl,
requests,
concurrency,
seed,
num_prefix_prompts,
artifact_dir,
url="http://localhost:8888",
):
"""Build genai-perf command based on prefix ratio"""
prefix_length = int(isl * prefix_ratio)
synthetic_input_length = int(isl * (1 - prefix_ratio))
return [
"genai-perf",
"profile",
"--model",
model,
"--tokenizer",
tokenizer, # Use the tokenizer parameter instead of model
"--endpoint-type",
"chat",
"--endpoint",
"v1/chat/completions",
"--streaming",
"--url",
url,
"--synthetic-input-tokens-mean",
str(synthetic_input_length),
"--synthetic-input-tokens-stddev",
str(round(synthetic_input_length / 4)),
"--output-tokens-mean",
str(osl),
"--output-tokens-stddev",
str(round(osl / 4)),
"--extra-inputs",
"ignore_eos:true",
"--extra-inputs",
'{"nvext":{"ignore_eos":true}}',
"--concurrency",
str(concurrency),
"--request-count",
str(requests),
"--num-dataset-entries",
str(requests),
"--random-seed",
str(seed),
"--prefix-prompt-length",
str(prefix_length),
"--num-prefix-prompts",
str(num_prefix_prompts),
"--artifact-dir",
artifact_dir,
"--",
"-v",
"--max-threads",
"256",
"-H",
"Authorization: Bearer NOT USED",
"-H",
"Accept: text/event-stream",
]
def get_gap_result(artifact_dir: str) -> dict:
"""Parse genai-perf results from JSON file"""
json_file_path = None
for root, _, files in os.walk(artifact_dir):
if "profile_export_genai_perf.json" in files:
json_file_path = os.path.join(root, "profile_export_genai_perf.json")
break
if json_file_path is None:
raise FileNotFoundError(
f"profile_export_genai_perf.json not found in {artifact_dir}"
)
with open(json_file_path, "r") as f:
return json.load(f)
def run_benchmark_single_url(
model,
tokenizer, # Add tokenizer parameter
prefix_ratio,
isl,
osl,
requests,
concurrency,
seed,
num_prefix_prompts,
artifact_dir,
url,
) -> Optional[Dict]:
"""Run genai-perf benchmark for a single URL"""
genai_perf_cmd = get_genai_perf_cmd(
model,
tokenizer, # Pass tokenizer parameter
prefix_ratio,
isl,
osl,
requests,
concurrency,
seed,
num_prefix_prompts,
artifact_dir,
url,
)
logger.info(f"Running command for URL {url}: {' '.join(genai_perf_cmd)}")
try:
gap_process = subprocess.run(
genai_perf_cmd, capture_output=True, text=True, check=True
)
logger.info(f"Genai-perf profiling completed successfully for URL {url}")
logger.info(gap_process.stdout)
gap_result = get_gap_result(artifact_dir)
return gap_result
except subprocess.CalledProcessError as e:
logger.error(f"Genai-perf failed for URL {url} with error code: {e.returncode}")
logger.error(f"stderr: {e.stderr}")
return None
def aggregate_results(results: List[Optional[Dict]]) -> Optional[Dict]:
"""Aggregate results from multiple URLs"""
if not results:
return None
# For TTFT, we take the average across all URLs
# For throughput, we sum across all URLs (total system throughput)
ttft_values = [r["time_to_first_token"]["avg"] for r in results if r is not None]
throughput_values = [
r["output_token_throughput"]["avg"] for r in results if r is not None
]
if not ttft_values or not throughput_values:
return None
aggregated = {
"time_to_first_token": {"avg": sum(ttft_values) / len(ttft_values)},
"output_token_throughput": {
"avg": sum(throughput_values) # Total throughput across all URLs
},
}
return aggregated
def run_benchmark(
model,
tokenizer, # Add tokenizer parameter
prefix_ratio,
isl,
osl,
requests,
concurrency,
seed,
num_prefix_prompts,
output_dir,
urls,
) -> Optional[Dict]:
"""Run genai-perf benchmark for a specific prefix ratio"""
logger.info(
f"Running benchmark with prefix_ratio={prefix_ratio}, seed={seed}, URLs={urls}"
)
# If single URL, maintain existing behavior
if isinstance(urls, str):
urls = [urls]
if len(urls) == 1:
artifact_dir = f"{output_dir}/prefix_ratio_{prefix_ratio}_seed_{seed}"
os.makedirs(artifact_dir, exist_ok=True)
return run_benchmark_single_url(
model,
tokenizer, # Pass tokenizer parameter
prefix_ratio,
isl,
osl,
requests,
concurrency,
seed,
num_prefix_prompts,
artifact_dir,
urls[0],
)
# Multiple URLs: split requests and concurrency
num_urls = len(urls)
base_requests_per_url = requests // num_urls
remainder_requests = requests % num_urls
base_concurrency_per_url = max(1, concurrency // num_urls)
# Launch parallel processes
processes = []
artifact_dirs = []
for i, url in enumerate(urls):
# Distribute remainder requests to first few URLs
url_requests = base_requests_per_url + (1 if i < remainder_requests else 0)
artifact_dir = f"{output_dir}/prefix_ratio_{prefix_ratio}_seed_{seed}_url_{i}"
os.makedirs(artifact_dir, exist_ok=True)
artifact_dirs.append(artifact_dir)
genai_perf_cmd = get_genai_perf_cmd(
model,
tokenizer, # Pass tokenizer parameter
prefix_ratio,
isl,
osl,
url_requests,
base_concurrency_per_url,
seed,
num_prefix_prompts,
artifact_dir,
url,
)
logger.info(f"Launching process for URL {url}: {' '.join(genai_perf_cmd)}")
process = subprocess.Popen(
genai_perf_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
processes.append((process, url, artifact_dir))
# Wait for all processes to complete and collect results
results: List[Optional[Dict]] = []
for process, url, artifact_dir in processes:
stdout, stderr = process.communicate()
if process.returncode == 0:
logger.info(f"Genai-perf completed successfully for URL {url}")
logger.info(stdout)
try:
gap_result = get_gap_result(artifact_dir)
results.append(gap_result)
except Exception as e:
logger.error(f"Failed to get results for URL {url}: {e}")
results.append(None)
else:
logger.error(
f"Genai-perf failed for URL {url} with error code: {process.returncode}"
)
logger.error(f"stderr: {stderr}")
results.append(None)
# Aggregate results
return aggregate_results(results)
def main():
parser = argparse.ArgumentParser(
description="Benchmark prefix ratios and plot results"
)
parser.add_argument(
"--model",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model name",
)
parser.add_argument(
"--tokenizer",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Tokenizer name (defaults to model)",
)
parser.add_argument(
"--url",
type=str,
nargs="+", # Accept multiple URLs
default=["http://localhost:8080"],
# default=["http://localhost:8090", "http://localhost:8090"],
help="Server URL(s). Can specify multiple URLs for parallel benchmarking",
)
parser.add_argument(
"--output-dir",
type=str,
default="kv_router",
help="Output directory for results",
)
parser.add_argument("--num-prefix-prompts", type=int, default=20)
parser.add_argument("--isl", type=int, default=14000, help="Input sequence length")
parser.add_argument("--osl", type=int, default=200, help="Output sequence length")
parser.add_argument("--requests", type=int, default=200, help="Number of requests")
parser.add_argument("--concurrency", type=int, default=20, help="Concurrency level")
parser.add_argument("--seed", type=int, default=420, help="Initial random seed")
parser.add_argument(
"--prefix-ratios",
type=float,
nargs="+",
default=[0.1, 0.3, 0.5, 0.7, 0.9],
help="List of prefix ratios to test",
)
args = parser.parse_args()
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Store results
prefix_ratios = []
ttft_values = []
throughput_values = []
current_seed = args.seed
# Run benchmarks for each prefix ratio
for prefix_ratio in args.prefix_ratios:
result = run_benchmark(
args.model,
args.tokenizer,
prefix_ratio,
args.isl,
args.osl,
args.requests,
args.concurrency,
current_seed,
args.num_prefix_prompts,
args.output_dir,
args.url, # Now passing list of URLs
)
if result is not None:
ttft = result["time_to_first_token"]["avg"]
throughput = result["output_token_throughput"]["avg"]
prefix_ratios.append(prefix_ratio)
ttft_values.append(ttft)
throughput_values.append(throughput)
logger.info(
f"Prefix ratio {prefix_ratio}: TTFT={ttft:.2f}ms, Throughput={throughput:.2f} tokens/s"
)
current_seed += 1
# Create plots
if prefix_ratios and ttft_values and throughput_values:
# Plot TTFT vs Prefix Ratio
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(prefix_ratios, ttft_values, "bo-", linewidth=2, markersize=8)
plt.xlabel("Prefix Ratio")
plt.ylabel("Time to First Token (ms)")
plt.title("TTFT vs Prefix Ratio")
plt.grid(True, alpha=0.3)
for i, (pr, ttft) in enumerate(zip(prefix_ratios, ttft_values)):
plt.annotate(
f"{ttft:.1f}ms",
(pr, ttft),
textcoords="offset points",
xytext=(0, 10),
ha="center",
)
# Plot Throughput vs Prefix Ratio
plt.subplot(1, 2, 2)
plt.plot(prefix_ratios, throughput_values, "ro-", linewidth=2, markersize=8)
plt.xlabel("Prefix Ratio")
plt.ylabel("Output Token Throughput (tokens/s)")
plt.title("Throughput vs Prefix Ratio")
plt.grid(True, alpha=0.3)
for i, (pr, thpt) in enumerate(zip(prefix_ratios, throughput_values)):
plt.annotate(
f"{thpt:.1f}",
(pr, thpt),
textcoords="offset points",
xytext=(0, 10),
ha="center",
)
plt.tight_layout()
# Save plot
plot_path = f"{args.output_dir}/prefix_ratio_performance.png"
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
logger.info(f"Performance plot saved to {plot_path}")
# Save results to JSON
results_data = {
"prefix_ratios": prefix_ratios,
"ttft_values": ttft_values,
"throughput_values": throughput_values,
"config": {
"model": args.model,
"tokenizer": args.tokenizer,
"isl": args.isl,
"osl": args.osl,
"requests": args.requests,
"concurrency": args.concurrency,
"initial_seed": args.seed,
},
}
results_path = f"{args.output_dir}/results_summary.json"
with open(results_path, "w") as f:
json.dump(results_data, f, indent=2)
logger.info(f"Results summary saved to {results_path}")
else:
logger.error("No successful benchmark results to plot")
if __name__ == "__main__":
main()
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Parse command-line arguments
NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1
USE_MOCKERS=false
EXTRA_ARGS=()
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
--num-workers)
NUM_WORKERS="$2"
shift 2
;;
--model-path)
MODEL_PATH="$2"
shift 2
;;
--tensor-parallel-size)
TENSOR_PARALLEL_SIZE="$2"
shift 2
;;
--mockers)
USE_MOCKERS=true
shift
;;
--)
shift
EXTRA_ARGS+=("$@")
break
;;
*)
# Collect all other arguments as vLLM/mocker arguments
EXTRA_ARGS+=("$1")
shift
;;
esac
done
# If no extra args provided, use defaults
if [ ${#EXTRA_ARGS[@]} -eq 0 ]; then
if [ "$USE_MOCKERS" = true ]; then
# Default args for mocker engine (only block-size needed as others are defaults)
EXTRA_ARGS=(
"--block-size" "64"
)
else
# Default args for vLLM engine (explicitly include block-size)
EXTRA_ARGS=(
"--enforce-eager"
"--max-num-batched-tokens" "16384"
"--max-model-len" "32768"
"--block-size" "64"
)
fi
fi
# Validate arguments
if ! [[ "$NUM_WORKERS" =~ ^[0-9]+$ ]] || [ "$NUM_WORKERS" -lt 1 ]; then
echo "Error: NUM_WORKERS must be a positive integer"
exit 1
fi
if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt 1 ]; then
echo "Error: TENSOR_PARALLEL_SIZE must be a positive integer"
exit 1
fi
# Calculate total GPUs needed
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
echo "Configuration:"
echo " Engine Type: $([ "$USE_MOCKERS" = true ] && echo "Mocker" || echo "vLLM")"
echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " Engine args: ${EXTRA_ARGS[*]}"
echo ""
PIDS=()
cleanup() {
echo -e "\nStopping all workers..."
kill "${PIDS[@]}" 2>/dev/null
wait
exit 0
}
trap cleanup SIGINT SIGTERM
echo "Starting $NUM_WORKERS workers..."
for i in $(seq 1 $NUM_WORKERS); do
{
echo "[Worker-$i] Starting..."
# Calculate GPU indices for this worker
START_GPU=$(( (i - 1) * TENSOR_PARALLEL_SIZE ))
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 ))
# Build CUDA_VISIBLE_DEVICES string
if [ "$TENSOR_PARALLEL_SIZE" -eq 1 ]; then
GPU_DEVICES="$START_GPU"
else
GPU_DEVICES=""
for gpu in $(seq $START_GPU $END_GPU); do
if [ -n "$GPU_DEVICES" ]; then
GPU_DEVICES="${GPU_DEVICES},$gpu"
else
GPU_DEVICES="$gpu"
fi
done
fi
if [ "$USE_MOCKERS" = true ]; then
# Run mocker engine (no GPU assignment needed)
exec python -m dynamo.mocker \
--model-path "$MODEL_PATH" \
--endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}"
else
echo "[Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine (exec with env for proper syntax)
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES python -m dynamo.vllm \
--model "$MODEL_PATH" \
--endpoint dyn://test.vllm.generate \
--tensor-parallel-size $TENSOR_PARALLEL_SIZE \
"${EXTRA_ARGS[@]}"
fi
} &
PIDS+=($!)
echo "Started worker $i (PID: $!)"
done
echo "All workers started. Press Ctrl+C to stop."
wait
echo "All workers completed."
\ No newline at end of file
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::component::{Component, Instance}; use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher; use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::watch; use tokio::sync::{RwLock, watch};
use super::KV_HIT_RATE_SUBJECT; use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig; use super::KvRouterConfig;
...@@ -91,93 +92,115 @@ impl KvScheduler { ...@@ -91,93 +92,115 @@ impl KvScheduler {
pub async fn start( pub async fn start(
component: Component, component: Component,
block_size: u32, block_size: u32,
mut instances_rx: watch::Receiver<Vec<Instance>>, instances_rx: watch::Receiver<Vec<Instance>>,
mut runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>, runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool, replica_sync: bool,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone(); let instances: Vec<Instance> = instances_rx.borrow().clone();
let mut runtime_configs: HashMap<i64, ModelRuntimeConfig> = let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone();
runtime_configs_rx.borrow_and_update().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock>
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>(); let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = {
let ns_clone = component.namespace().clone(); let mut initial_map = HashMap::new();
tokio::spawn(async move { for instance in &instances {
let mut event_rx = event_rx; let worker_id = instance.instance_id;
while let Some(event) = event_rx.recv().await { let config = runtime_configs.get(&worker_id).cloned();
if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await { if config.is_some() {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e); tracing::info!("Runtime config found for worker_id: {}", worker_id);
} }
initial_map.insert(worker_id, config);
} }
}); Arc::new(RwLock::new(initial_map))
};
let worker_ids: Vec<i64> = instances let worker_ids: Vec<i64> = instances
.iter() .iter()
.map(|instance| instance.instance_id) .map(|instance| instance.instance_id)
.collect(); .collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new( let slots = Arc::new(ActiveSequencesMultiWorker::new(
component, component.clone(),
block_size as usize, block_size as usize,
worker_ids, worker_ids,
replica_sync, replica_sync,
)); ));
let slots_clone = slots.clone(); // Spawn background task to monitor and update workers_with_configs
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024); let workers_monitor = workers_with_configs.clone();
// Background task to handle scheduling requests let slots_monitor = slots.clone();
let mut instances_monitor_rx = instances_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().primary_token();
tokio::spawn(async move { tokio::spawn(async move {
let mut request_rx = request_rx; tracing::trace!("workers monitoring task started");
tracing::trace!("scheduler background task started");
let mut workers_with_configs: HashMap<i64, Option<ModelRuntimeConfig>> = HashMap::new();
let mut needs_rebuild = true;
loop { loop {
// Check for instance updates (non-blocking) // Wait for either instances or configs to change
let instances_changed = instances_rx.has_changed(); tokio::select! {
let configs_changed = runtime_configs_rx.has_changed(); _ = monitor_cancel_token.cancelled() => {
tracing::trace!("workers monitoring task shutting down");
match instances_changed {
Ok(true) => {
instances = instances_rx.borrow_and_update().clone();
let worker_ids: Vec<i64> = instances
.iter()
.map(|instance| instance.instance_id)
.collect();
slots_clone.update_workers(worker_ids);
needs_rebuild = true;
}
Ok(false) => {}
Err(_) => {
tracing::warn!("endpoint watch sender shutdown");
break; break;
} }
} result = instances_monitor_rx.changed() => {
if result.is_err() {
// Check for runtime config updates tracing::warn!("endpoint watch sender shutdown in monitor");
match configs_changed { break;
Ok(true) => { }
runtime_configs = runtime_configs_rx.borrow_and_update().clone();
needs_rebuild = true;
} }
Ok(false) => {} result = configs_monitor_rx.changed() => {
Err(_) => { if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown"); tracing::warn!("runtime configs watch sender shutdown in monitor");
break;
}
} }
} }
// Rebuild workers hashmap only when needed // Get the latest values from both channels
if needs_rebuild { let new_instances = instances_monitor_rx.borrow_and_update().clone();
workers_with_configs.clear(); let new_configs = configs_monitor_rx.borrow_and_update().clone();
for instance in &instances {
let worker_id = instance.instance_id; // Update workers when instances change
let config = runtime_configs.get(&worker_id).cloned(); let worker_ids: Vec<i64> = new_instances
if config.is_none() { .iter()
tracing::warn!("Runtime config not found for worker_id: {}", worker_id); .map(|instance| instance.instance_id)
} .collect();
workers_with_configs.insert(worker_id, config); slots_monitor.update_workers(worker_ids);
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
workers_map.clear();
for instance in &new_instances {
let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
} }
needs_rebuild = false; workers_map.insert(worker_id, config);
}
tracing::trace!(
"Updated workers_with_configs with {} workers",
workers_map.len()
);
}
tracing::trace!("workers monitoring task shutting down");
});
let slots_clone = slots.clone();
let workers_scheduler = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
let ns_clone = component.namespace().clone();
// Background task to handle scheduling requests
tokio::spawn(async move {
let mut request_rx = request_rx;
tracing::trace!("scheduler background task started");
loop {
// Check for cancellation at beginning of loop
if scheduler_cancel_token.is_cancelled() {
tracing::trace!("scheduler background task shutting down");
break;
} }
// Wait for a new request // Wait for a new request
...@@ -197,14 +220,18 @@ impl KvScheduler { ...@@ -197,14 +220,18 @@ impl KvScheduler {
request.decode_blocks = decode_blocks; request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
match selector.select_worker(&workers_with_configs, &request, block_size) { // Read the current workers configuration
let workers = workers_scheduler.read().await.clone();
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => { Ok(selection) => {
if let Err(e) = event_tx.send(KVHitRateEvent { let event = KVHitRateEvent {
worker_id: selection.worker_id, worker_id: selection.worker_id,
isl_blocks: selection.required_blocks as usize, isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks, overlap_blocks: selection.overlap_blocks,
}) { };
tracing::warn!("Failed to send KV hit rate event: {:?}", e); if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
} }
let response = SchedulingResponse { let response = SchedulingResponse {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment