Unverified commit f978f4d1, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: dp rank routing (#3597)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 29f5b822
......@@ -7,6 +7,7 @@
NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1
DATA_PARALLEL_SIZE=1
USE_MOCKERS=false
USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill
......@@ -28,6 +29,10 @@ while [[ $# -gt 0 ]]; do
TENSOR_PARALLEL_SIZE="$2"
shift 2
;;
--data-parallel-size)
DATA_PARALLEL_SIZE="$2"
shift 2
;;
--mockers)
USE_MOCKERS=true
shift
......@@ -114,13 +119,19 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt
exit 1
fi
if ! [[ "$DATA_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$DATA_PARALLEL_SIZE" -lt 1 ]; then
echo "Error: DATA_PARALLEL_SIZE must be a positive integer"
exit 1
fi
if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then
echo "Error: BASE_GPU_OFFSET must be a non-negative integer"
exit 1
fi
# Calculate total GPUs needed
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
# Calculate total GPUs needed (TP * DP per worker)
GPUS_PER_WORKER=$((TENSOR_PARALLEL_SIZE * DATA_PARALLEL_SIZE))
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * GPUS_PER_WORKER))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:"
if [ "$USE_MOCKERS" = true ]; then
......@@ -135,6 +146,8 @@ echo " Mode: $MODE"
echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Data Parallel Size: $DATA_PARALLEL_SIZE"
echo " GPUs per worker: $GPUS_PER_WORKER"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU"
echo " Engine args: ${EXTRA_ARGS[*]}"
......@@ -155,14 +168,16 @@ echo "Starting $NUM_WORKERS $MODE workers..."
for i in $(seq 1 $NUM_WORKERS); do
{
echo "[${MODE^} Worker-$i] Starting..."
MODE_CAPITALIZED=$(echo "$MODE" | sed 's/\(.\)/\U\1/')
echo "[$MODE_CAPITALIZED Worker-$i] Starting..."
# Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE ))
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 ))
# Each worker needs TP * DP GPUs
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * GPUS_PER_WORKER ))
END_GPU=$(( START_GPU + GPUS_PER_WORKER - 1 ))
# Build CUDA_VISIBLE_DEVICES string
if [ "$TENSOR_PARALLEL_SIZE" -eq 1 ]; then
# Build CUDA_VISIBLE_DEVICES string for all GPUs (TP * DP)
if [ "$GPUS_PER_WORKER" -eq 1 ]; then
GPU_DEVICES="$START_GPU"
else
GPU_DEVICES=""
......@@ -177,12 +192,17 @@ for i in $(seq 1 $NUM_WORKERS); do
if [ "$USE_MOCKERS" = true ]; then
# Run mocker engine (no GPU assignment needed)
exec python -m dynamo.mocker \
--model-path "$MODEL_PATH" \
--endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}"
MOCKER_ARGS=()
MOCKER_ARGS+=("--model-path" "$MODEL_PATH")
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
exec python -m dynamo.mocker "${MOCKER_ARGS[@]}"
elif [ "$USE_TRTLLM" = true ]; then
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
......@@ -195,11 +215,14 @@ for i in $(seq 1 $NUM_WORKERS); do
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}"
else
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker")
fi
......
......@@ -99,6 +99,7 @@ class StandaloneRouterHandler:
"eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []),
"disaggregated_params": request.get("disaggregated_params"),
"dp_rank": request.get("dp_rank"),
"extra_args": request.get("extra_args", {}),
}
......
......@@ -33,7 +33,7 @@ class BaseWorkerHandler(ABC):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.kv_publisher = None
self.kv_publishers = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
@abstractmethod
......@@ -81,9 +81,16 @@ class BaseWorkerHandler(ABC):
"""Override in subclasses if cleanup is needed."""
pass
async def generate_tokens(self, prompt, sampling_params, request_id):
async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None
):
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
gen = self.engine_client.generate(
prompt,
sampling_params,
request_id,
data_parallel_rank=data_parallel_rank,
)
num_output_tokens_so_far = 0
try:
......@@ -211,10 +218,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return
logger.warning(f"Prefill error: {e}, falling back to local prefill")
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id):
try:
async for tok in self.generate_tokens(
prompt, sampling_params, request_id
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
):
yield tok
except EngineDeadError as e:
......@@ -241,9 +250,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True):
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
......
......@@ -5,7 +5,6 @@ import asyncio
import logging
import os
import signal
from typing import Optional
import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher
......@@ -107,21 +106,26 @@ def setup_kv_event_publisher(
component,
generate_endpoint,
vllm_config,
) -> Optional[ZmqKvEventPublisher]:
):
"""
Set up KV event publisher for prefix caching if enabled.
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Returns:
ZmqKvEventPublisher if prefix caching is enabled, None otherwise.
List of ZmqKvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.
"""
if not config.engine_args.enable_prefix_caching:
return None
# TODO: We start off with a valid endpoint, then increment it by dp_rank, so it
# may no longer be valid. Let's remove the increment behavior from vLLM and here.
# Get data_parallel_size to create publishers for all dp_ranks
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
kv_publishers = []
for dp_rank in range(data_parallel_size):
# Each dp_rank publishes to a different port
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
data_parallel_rank=dp_rank,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
......@@ -130,10 +134,13 @@ def setup_kv_event_publisher(
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
logger.info(f"Worker reading KV events from {zmq_endpoint}")
logger.info(
f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
)
return kv_publisher
return kv_publishers if kv_publishers else None
def setup_vllm_engine(config, stat_logger=None):
......@@ -200,12 +207,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params
)
# Set up KV event publisher for prefix caching if enabled
kv_publisher = setup_kv_event_publisher(
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
if kv_publishers:
handler.kv_publishers = kv_publishers
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
......@@ -285,12 +292,12 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client,
)
# Set up KV event publisher for prefix caching if enabled
kv_publisher = setup_kv_event_publisher(
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
if kv_publishers:
handler.kv_publishers = kv_publishers
if config.engine_args.disable_log_stats is False:
from prometheus_client import REGISTRY
......@@ -311,6 +318,12 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(
vllm_config.parallel_config, "data_parallel_size", 1
)
runtime_config.data_parallel_size = data_parallel_size
await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
......
......@@ -322,7 +322,11 @@ The `KvPushRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
- **`best_worker(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: **[DEPRECATED - use `best_worker()` instead]** Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
......
......@@ -207,6 +207,7 @@ fn kv_event_create_stored_from_parts(
parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
}),
event_id: kv_params.event_id,
dp_rank: 0,
}
}
......@@ -224,6 +225,7 @@ fn kv_event_create_removed_from_parts(
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: 0,
}
}
......
......@@ -101,7 +101,7 @@ impl WorkerMetricsPublisher {
#[derive(Clone)]
pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)]
pub worker_id: i64,
pub worker_id: WorkerId,
#[pyo3(get, set)]
pub kv_block_size: usize,
#[pyo3(get, set)]
......@@ -120,7 +120,7 @@ impl ZmqKvEventPublisherConfig {
zmq_topic = "".to_string()
))]
pub fn new(
worker_id: i64,
worker_id: WorkerId,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
......@@ -234,13 +234,20 @@ impl Drop for ZmqKvEventListener {
pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
kv_block_size: usize,
dp_rank: DpRank,
warning_count: Arc<AtomicU32>,
}
#[pymethods]
impl KvEventPublisher {
#[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> {
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
......@@ -256,6 +263,7 @@ impl KvEventPublisher {
Ok(Self {
inner: inner.into(),
kv_block_size,
dp_rank,
warning_count: Arc::new(AtomicU32::new(0)),
})
}
......@@ -286,6 +294,7 @@ impl KvEventPublisher {
&self.warning_count,
),
}),
dp_rank: self.dp_rank,
};
self.inner.publish(event).map_err(to_pyerr)
......@@ -299,6 +308,7 @@ impl KvEventPublisher {
let event = KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: self.dp_rank,
};
self.inner.publish(event).map_err(to_pyerr)
......@@ -314,8 +324,13 @@ pub(crate) struct OverlapScores {
#[pymethods]
impl OverlapScores {
#[getter]
fn scores(&self) -> HashMap<llm_rs::kv_router::indexer::WorkerId, u32> {
self.inner.scores.clone()
fn scores(&self) -> HashMap<(i64, u32), u32> {
// Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
self.inner
.scores
.iter()
.map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
.collect()
}
#[getter]
......@@ -361,7 +376,7 @@ impl RadixTree {
fn apply_event(
&mut self,
_py: Python,
worker_id: i64,
worker_id: WorkerId,
kv_cache_event_bytes: &[u8],
) -> PyResult<()> {
let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent =
......@@ -377,12 +392,12 @@ impl RadixTree {
Ok(())
}
fn remove_worker(&mut self, _py: Python, worker_id: i64) -> PyResult<()> {
fn remove_worker(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.remove_worker(worker_id);
Ok(())
}
fn clear_all_blocks(&mut self, _py: Python, worker_id: i64) -> PyResult<()> {
fn clear_all_blocks(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.clear_all_blocks(worker_id);
Ok(())
}
......@@ -517,16 +532,19 @@ impl ApproxKvIndexer {
})
}
#[pyo3(signature = (tokens, worker_id, dp_rank=0))]
fn process_routing_decision_for_request<'p>(
&self,
py: Python<'p>,
tokens: Vec<u32>,
worker_id: i64,
worker_id: WorkerId,
dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
indexer
.process_routing_decision_for_request(tokens.as_slice(), worker_id)
.process_routing_decision_for_request(tokens.as_slice(), worker)
.await
.map_err(to_pyerr)?;
Ok(())
......@@ -538,7 +556,7 @@ impl ApproxKvIndexer {
#[derive(Clone)]
pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)]
pub worker_id: i64,
pub worker_id: WorkerId,
#[pyo3(get, set)]
pub request_active_slots: u64,
#[pyo3(get, set)]
......@@ -784,7 +802,7 @@ impl WorkerStats {
request_active_slots: u64,
request_total_slots: u64,
num_requests_waiting: u64,
data_parallel_rank: Option<u32>,
data_parallel_rank: Option<DpRank>,
) -> Self {
Self(RsWorkerStats {
data_parallel_rank,
......@@ -961,7 +979,7 @@ impl KvPushRouter {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, extra_args=None))]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None))]
fn generate<'p>(
&self,
py: Python<'p>,
......@@ -971,7 +989,8 @@ impl KvPushRouter {
sampling_options: Option<PyObject>,
output_options: Option<PyObject>,
router_config_override: Option<PyObject>,
worker_id: Option<i64>,
worker_id: Option<WorkerId>,
dp_rank: Option<DpRank>,
extra_args: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults
......@@ -1027,6 +1046,7 @@ impl KvPushRouter {
.sampling_options(sampling_options)
.output_options(output_options)
.router_config_override(router_config_override)
.dp_rank(dp_rank)
.extra_args(extra_args);
// Set backend_instance_id if worker_id is provided
......@@ -1053,6 +1073,43 @@ impl KvPushRouter {
Self::process_request_to_stream(py, self.inner.clone(), request)
}
/// Query which worker (including its data parallel rank) best matches `token_ids`.
///
/// Returns an awaitable that resolves to the tuple
/// `(worker_id, dp_rank, overlap_blocks)` taken from the chooser's
/// `find_best_match` result.
///
/// * `token_ids` - token ids to match against.
/// * `router_config_override` - optional Python object that is depythonized into a
///   `RouterConfigOverride`; a failed conversion is returned as a Python error.
/// * `request_id` - when provided, `find_best_match` is called with
///   `update_states = true`; when omitted, with `update_states = false`.
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker<'p>(
    &self,
    py: Python<'p>,
    token_ids: Vec<u32>,
    router_config_override: Option<PyObject>,
    request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
    // `py` already witnesses that the GIL is held, so bind the object with it
    // directly instead of re-entering `Python::with_gil`.
    let router_config_override: Option<llm_rs::kv_router::RouterConfigOverride> =
        router_config_override
            .map(|obj| depythonize(obj.bind(py)).map_err(to_pyerr))
            .transpose()?;

    let chooser = self.inner.chooser.clone();
    // Must be computed before `request_id` is moved into the future below.
    let update_states = request_id.is_some();

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let (best_worker, overlap_blocks) = chooser
            .find_best_match(
                request_id.as_deref(),
                &token_ids,
                router_config_override.as_ref(),
                update_states,
            )
            .await
            .map_err(to_pyerr)?;

        Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
    })
}
/// Deprecated: Use `best_worker()` instead which returns (worker_id, dp_rank, overlap_blocks)
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker_id<'p>(
&self,
......@@ -1061,6 +1118,16 @@ impl KvPushRouter {
router_config_override: Option<PyObject>,
request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
// Issue deprecation warning
let warnings = py.import("warnings")?;
warnings.call_method1(
"warn",
(
"best_worker_id() is deprecated. Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks)",
py.get_type::<pyo3::exceptions::PyDeprecationWarning>(),
),
)?;
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
......@@ -1075,7 +1142,7 @@ impl KvPushRouter {
let update_states = request_id.is_some();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = chooser
let (best_worker, overlap_blocks) = chooser
.find_best_match(
request_id.as_deref(),
&token_ids,
......@@ -1085,8 +1152,8 @@ impl KvPushRouter {
.await
.map_err(to_pyerr)?;
// Return a tuple of (worker_id, overlap_blocks)
Ok((worker_id, overlap_blocks))
// Return only worker_id and overlap_blocks for backward compatibility
Ok((best_worker.worker_id, overlap_blocks))
})
}
......@@ -1130,6 +1197,7 @@ impl KvPushRouter {
.await
.map_err(to_pyerr)?;
// Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
// Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
Python::with_gil(|py| {
pythonize(py, &loads)
......
......@@ -44,6 +44,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser = reasoning_parser;
}
/// Python property setter: stores `data_parallel_size` in the wrapped
/// runtime config's `data_parallel_size` field.
#[setter]
fn set_data_parallel_size(&mut self, data_parallel_size: u32) {
    self.inner.data_parallel_size = data_parallel_size;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
......
......@@ -778,16 +778,21 @@ class KvEventPublisher:
...
def __init__(
self, component: Component, worker_id: int, kv_block_size: int
self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0
) -> None:
"""
Create a `KvEventPublisher` object
Args:
component: The component to publish events for
worker_id: The worker ID
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
"""
def publish_stored(
self,
event_id,
int,
event_id: int,
token_ids: List[int],
num_block_tokens: List[int],
block_hashes: List[int],
......@@ -796,12 +801,24 @@ class KvEventPublisher:
) -> None:
"""
Publish a KV stored event.
Args:
event_id: The event ID
token_ids: List of token IDs
num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
"""
...
def publish_removed(self, event_id, int, block_hashes: List[int]) -> None:
def publish_removed(self, event_id: int, block_hashes: List[int]) -> None:
"""
Publish a KV removed event.
Args:
event_id: The event ID
block_hashes: List of block hashes to remove (signed 64-bit integers)
"""
...
......@@ -1199,6 +1216,7 @@ class KvPushRouter:
output_options: Optional[JsonLike] = None,
router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None,
dp_rank: Optional[int] = None,
) -> AsyncIterator[JsonLike]:
"""
Generate text using the KV-aware router.
......@@ -1213,6 +1231,10 @@ class KvPushRouter:
worker_id: Optional worker ID to route to directly. If set, the request
will be sent to this specific worker and router states will be
updated accordingly.
dp_rank: Optional data parallel rank to route to. If set along with worker_id,
the request will be routed to the specific (worker_id, dp_rank) pair.
If only dp_rank is set, the router will select the best worker but
force routing to the specified dp_rank.
Returns:
An async iterator yielding generation responses
......@@ -1220,10 +1242,36 @@ class KvPushRouter:
Note:
- If worker_id is set, the request bypasses KV matching and routes directly
to the specified worker while still updating router states.
- dp_rank allows targeting a specific data parallel replica when workers have
multiple replicas (data_parallel_size > 1).
- This is different from query_instance_id which doesn't route the request.
"""
...
async def best_worker(
    self,
    token_ids: List[int],
    router_config_override: Optional[JsonLike] = None,
    request_id: Optional[str] = None,
) -> Tuple[int, int, int]:
    """
    Find the best matching worker for the given tokens.

    Args:
        token_ids: List of token IDs to find matches for
        router_config_override: Optional router configuration override
        request_id: Optional request ID. If provided, router states will be updated
            to track this request (active blocks, lifecycle events). If not
            provided, this is a query-only operation that doesn't affect state.

    Returns:
        A tuple of (worker_id, dp_rank, overlap_blocks) where:
            - worker_id: The ID of the best matching worker
            - dp_rank: The data parallel rank of the selected worker
            - overlap_blocks: The number of overlapping blocks found

    Note:
        When called with `request_id`, you must call `mark_prefill_complete()`
        and `free()` at the appropriate lifecycle points to keep load tracking
        accurate.
    """
    ...
async def best_worker_id(
self,
token_ids: List[int],
......@@ -1231,6 +1279,8 @@ class KvPushRouter:
request_id: Optional[str] = None,
) -> Tuple[int, int]:
"""
[DEPRECATED] Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks).
Find the best matching worker for the given tokens.
Args:
......@@ -1244,6 +1294,9 @@ class KvPushRouter:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
.. deprecated::
Use :meth:`best_worker` instead which also returns dp_rank.
"""
...
......@@ -1260,8 +1313,13 @@ class KvPushRouter:
Returns:
A list of dictionaries, each containing:
- worker_id: The worker ID
- dp_rank: The data parallel rank
- potential_prefill_tokens: Number of tokens that would need prefill
- potential_decode_blocks: Number of blocks currently in decode phase
Note:
Each (worker_id, dp_rank) pair is returned as a separate entry.
If you need aggregated loads per worker_id, sum the values manually.
"""
...
......@@ -1287,7 +1345,7 @@ class KvPushRouter:
Note:
This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing.
`best_worker()` with `request_id` for custom routing.
"""
...
......@@ -1304,7 +1362,7 @@ class KvPushRouter:
Note:
This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing.
`best_worker()` with `request_id` for custom routing.
"""
...
......
......@@ -85,17 +85,21 @@ async def test_radix_tree_binding(distributed_runtime):
overlap_scores = radix_tree.find_matches([0])
# Verify the results
# Note: scores is now Dict[(worker_id, dp_rank), score]
assert overlap_scores.scores is not None
assert (
len(overlap_scores.scores) == 1
), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}"
assert worker_id in overlap_scores.scores, f"Worker {worker_id} not found in scores"
worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert (
overlap_scores.scores[worker_id] == 1
), f"Expected score 1 for worker {worker_id}, got {overlap_scores.scores[worker_id]}"
worker_key in overlap_scores.scores
), f"Worker {worker_key} not found in scores"
assert (
overlap_scores.scores[worker_key] == 1
), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}"
print(
f"✓ RadixTree test passed: worker {worker_id} has score {overlap_scores.scores[worker_id]}"
f"✓ RadixTree test passed: worker {worker_key} has score {overlap_scores.scores[worker_key]}"
)
......@@ -130,24 +134,25 @@ async def test_event_handler(distributed_runtime):
event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously
# Retry loop for CI environments where processing may take longer
worker_key = (worker_id, 0) # (worker_id, dp_rank)
for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id)
if (
scores.scores
and worker_id in scores.scores
and scores.scores[worker_id] == 1
and worker_key in scores.scores
and scores.scores[worker_key] == 1
):
break
if retry == 9: # Last iteration
# Provide detailed error message for debugging
assert scores.scores, f"No scores found after {(retry+1)*0.5}s"
assert (
worker_id in scores.scores
), f"Worker {worker_id} not in scores after {(retry+1)*0.5}s"
worker_key in scores.scores
), f"Worker {worker_key} not in scores after {(retry+1)*0.5}s"
assert (
scores.scores[worker_id] == 1
), f"Expected score 1, got {scores.scores.get(worker_id)} after {(retry+1)*0.5}s"
scores.scores[worker_key] == 1
), f"Expected score 1, got {scores.scores.get(worker_key)} after {(retry+1)*0.5}s"
# remove event
event_publisher.remove_event()
......@@ -185,8 +190,9 @@ async def test_approx_kv_indexer(distributed_runtime):
scores = await indexer.find_matches_for_request(tokens)
assert scores.scores
assert worker_id in scores.scores
assert scores.scores[worker_id] == 2
worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert worker_key in scores.scores
assert scores.scores[worker_key] == 2
class EventPublisher:
......@@ -281,7 +287,7 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics["request_active_slots"],
expected_metrics["request_total_slots"],
expected_metrics["num_requests_waiting"],
None,
0, # data_parallel_rank (0 = DP not enabled)
)
kv_stats = KvStats(
......
......@@ -38,7 +38,9 @@ use crate::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scoring::ProcessedEndpoints,
subscriber::start_kv_router_background,
......@@ -74,7 +76,7 @@ pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
pub trait WorkerSelector {
fn select_worker(
&self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
......@@ -316,7 +318,7 @@ impl KvRouter {
}
/// Given these tokens, find the worker with the best match in its KV cache.
/// Returned overlap amount is in number of blocks.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
pub async fn find_best_match(
&self,
......@@ -324,7 +326,7 @@ impl KvRouter {
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> anyhow::Result<(i64, u32)> {
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
......@@ -350,7 +352,7 @@ impl KvRouter {
(false, false) => (None, None),
};
let best_worker_id = self
let best_worker = self
.scheduler
.schedule(
context_id.map(|s| s.to_string()),
......@@ -364,17 +366,17 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
.process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await
.unwrap();
};
let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
.get(&best_worker)
.copied()
.unwrap_or(0);
Ok((best_worker_id, overlap_amount))
Ok((best_worker, overlap_amount))
}
pub async fn add_request(
......@@ -382,7 +384,7 @@ impl KvRouter {
request_id: String,
tokens: &[u32],
overlap_blocks: u32,
worker_id: i64,
worker: WorkerWithDpRank,
) {
let isl_tokens = tokens.len();
......@@ -397,7 +399,7 @@ impl KvRouter {
maybe_seq_hashes,
isl_tokens,
overlap_blocks,
worker_id,
worker,
)
.await;
}
......@@ -450,12 +452,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
// Handle different request types
let response = match request {
RouterRequest::New { tokens } => {
let (worker_id, overlap_blocks) = self
let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true)
.await?;
RouterResponse::New {
worker_id,
worker_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_blocks,
}
}
......@@ -523,24 +526,45 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it and manually add the request to track it
if !query_instance_id {
let (instance_id, dp_rank, overlap_amount) = if let Some(id) =
request.backend_instance_id
{
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
}
// Compute actual overlap blocks by querying the indexer
let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
self.chooser
.add_request(context_id.clone(), &request.token_ids, 0, id)
.add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
}
(id, 0)
(id, dp_rank, overlap_blocks)
} else {
// Otherwise, find the best match
self.chooser
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
)
.await?
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// if request has the annotation "query_instance_id",
......@@ -564,6 +588,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
......
......@@ -27,11 +27,11 @@ use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
WorkerId, compute_block_hash_for_seq,
compute_block_hash_for_seq,
};
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
KvCacheStoredBlockData, LocalBlockHash, WorkerId, WorkerWithDpRank,
};
#[derive(Debug)]
......@@ -44,8 +44,8 @@ struct MatchRequest {
#[derive(Debug)]
struct RouterResult {
/// The id of the selected worker.
worker_id: WorkerId,
/// The worker (with dp_rank) that was selected.
worker: WorkerWithDpRank,
/// The local hashes of the tokens sent to the worker.
local_hashes: Vec<LocalBlockHash>,
......@@ -58,8 +58,8 @@ struct RouterResult {
struct TimerEntry {
/// The key of the timer.
key: ExternalSequenceBlockHash,
/// The worker id that stored this block.
worker: WorkerId,
/// The worker (with dp_rank) that stored this block.
worker: WorkerWithDpRank,
}
/// A data structure to manage a collection of timers, addressable by a key.
......@@ -237,10 +237,11 @@ impl ApproxKvIndexer {
event_id += 1;
let event = RouterEvent::new(
result.worker_id,
result.worker.worker_id,
KvCacheEvent {
event_id,
data: stored_event,
dp_rank: result.worker.dp_rank,
}
);
......@@ -248,7 +249,7 @@ impl ApproxKvIndexer {
timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry {
key: ExternalSequenceBlockHash(*h),
worker: result.worker_id,
worker: result.worker,
}).collect());
}
......@@ -269,12 +270,13 @@ impl ApproxKvIndexer {
event_id += 1;
let event = RouterEvent::new(
e.worker,
e.worker.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key],
}),
dp_rank: e.worker.dp_rank,
}
);
......@@ -307,13 +309,13 @@ impl ApproxKvIndexer {
/// Core function to process a routing decision with pre-computed hashes
pub async fn process_routing_decision(
&self,
worker_id: WorkerId,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
self.route_tx
.send(RouterResult {
worker_id,
worker,
local_hashes,
sequence_hashes,
})
......@@ -327,7 +329,7 @@ impl ApproxKvIndexer {
pub async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker_id: WorkerId,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
......@@ -338,7 +340,7 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
}
......@@ -526,14 +528,20 @@ mod tests {
// 2. Inform indexer about routing decision
indexer
.process_routing_decision_for_request(&tokens, worker_id)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await
.unwrap();
// Poll until we observe the match being registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_id).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_id))
.copied()
== Some(1)
})
.await;
......@@ -554,14 +562,18 @@ mod tests {
let worker_id: WorkerId = 7;
indexer
.process_routing_decision_for_request(&tokens, worker_id)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await
.unwrap();
// Wait until the worker is registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.contains_key(&worker_id)
s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
})
.await;
......@@ -571,7 +583,8 @@ mod tests {
// Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_id)
!s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
})
.await;
}
......@@ -590,19 +603,31 @@ mod tests {
// Register on both workers
indexer
.process_routing_decision_for_request(&tokens, worker_0)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await
.unwrap();
indexer
.process_routing_decision_for_request(&tokens, worker_1)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await
.unwrap();
// Ensure both workers are registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1)
&& s.scores.get(&worker_1).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
})
.await;
......@@ -612,7 +637,12 @@ mod tests {
// Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_0) && s.scores.get(&worker_1).copied() == Some(1)
!s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
})
.await;
}
......@@ -631,14 +661,20 @@ mod tests {
// Register Sequence A on worker A
indexer
.process_routing_decision_for_request(&seq_a, worker_a)
.process_routing_decision_for_request(
&seq_a,
WorkerWithDpRank::from_worker_id(worker_a),
)
.await
.unwrap();
// Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
s.scores.get(&worker_a).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a))
.copied()
== Some(1)
})
.await;
......@@ -649,7 +685,12 @@ mod tests {
let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();
// Expect worker A to have an overlap score of 1 (shared first block)
assert_eq!(overlap.scores.get(&worker_a), Some(&1));
assert_eq!(
overlap
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a)),
Some(&1)
);
}
/// When the same block resides on multiple workers, all should appear in the overlap scores.
......@@ -666,25 +707,47 @@ mod tests {
// Register the same sequence on two different workers
indexer
.process_routing_decision_for_request(&tokens, worker_0)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await
.unwrap();
indexer
.process_routing_decision_for_request(&tokens, worker_1)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await
.unwrap();
// Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1)
&& s.scores.get(&worker_1).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
})
.await;
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert_eq!(scores.scores.get(&worker_0), Some(&1));
assert_eq!(scores.scores.get(&worker_1), Some(&1));
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0)),
Some(&1)
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1)),
Some(&1)
);
}
}
This diff is collapsed.
......@@ -5,6 +5,34 @@ use crate::tokens::{SequenceHash, Token};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A worker identifier.
pub type WorkerId = i64;
/// A data parallel rank identifier.
pub type DpRank = u32;
/// A worker identifier combined with its data parallel rank.
/// Used for routing decisions in data parallel setups.
/// dp_rank = 0 indicates either DP not enabled or the first rank.
///
/// The derived `Eq` and `Hash` impls let a value serve directly as a
/// `HashMap`/`HashSet` key, and `Copy` lets it be passed by value.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerWithDpRank {
/// Identifier of the worker.
pub worker_id: WorkerId,
/// Data parallel rank on that worker; 0 when DP is not enabled or for the first rank.
pub dp_rank: DpRank,
}
impl WorkerWithDpRank {
pub fn new(worker_id: WorkerId, dp_rank: DpRank) -> Self {
Self { worker_id, dp_rank }
}
pub fn from_worker_id(worker_id: WorkerId) -> Self {
Self {
worker_id,
dp_rank: 0,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest {
......@@ -26,15 +54,24 @@ impl Default for RouterRequest {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterResponse {
New { worker_id: i64, overlap_blocks: u32 },
PrefillMarked { success: bool },
FreeMarked { success: bool },
New {
worker_id: WorkerId,
#[serde(default)]
dp_rank: DpRank,
overlap_blocks: u32,
},
PrefillMarked {
success: bool,
},
FreeMarked {
success: bool,
},
}
#[derive(Debug)]
pub struct WorkerSelectionResult {
/// The worker id of the selected worker
pub worker_id: i64,
/// The full worker information including dp_rank
pub worker: WorkerWithDpRank,
/// The total number of blocks required to prefill the request
pub required_blocks: u64,
......@@ -54,7 +91,7 @@ pub struct ForwardPassMetrics {
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>,
pub data_parallel_rank: Option<DpRank>,
pub request_active_slots: u64,
pub request_total_slots: u64,
pub num_requests_waiting: u64,
......@@ -136,7 +173,7 @@ impl From<i64> for ExternalSequenceBlockHash {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillEvent {
pub request_id: String,
pub worker_id: i64,
pub worker_id: WorkerId,
pub data: PrefillEventData,
pub router_id: Uuid,
}
......@@ -155,7 +192,7 @@ pub enum PrefillEventData {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActiveSequenceEvent {
pub request_id: String,
pub worker_id: i64,
pub worker: WorkerWithDpRank,
pub data: ActiveSequenceEventData,
pub router_id: Uuid,
}
......@@ -199,6 +236,9 @@ pub struct KvCacheEvent {
pub event_id: u64,
/// The data associated with the event.
pub data: KvCacheEventData,
/// The data parallel rank of the worker emitting this event (0 if DP not enabled).
#[serde(default)]
pub dp_rank: DpRank,
}
/// Represents the data associated with a cache event.
......@@ -313,6 +353,7 @@ mod tests {
let event = KvCacheEvent {
event_id: 1,
data: event_data,
dp_rank: 0,
};
let events = KvCacheEvents {
......
......@@ -326,13 +326,16 @@ pub async fn start_zmq_listener(
};
tracing::trace!(
"ZMQ listener on {} received batch with {} events (seq={})",
"ZMQ listener on {} received batch with {} events (seq={}, dp_rank={})",
zmq_endpoint,
batch.events.len(),
seq
seq,
batch.data_parallel_rank
);
let dp_rank = batch.data_parallel_rank;
for raw_event in batch.events.into_iter() {
let event = convert_event(raw_event, seq, kv_block_size, &warning_count);
let event = convert_event(raw_event, seq, kv_block_size, dp_rank, &warning_count);
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped";
......@@ -356,6 +359,7 @@ fn convert_event(
raw: RawKvEvent,
event_id: u64,
kv_block_size: u32,
dp_rank: u32,
warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent {
match raw {
......@@ -387,6 +391,7 @@ fn convert_event(
warning_count,
),
}),
dp_rank,
}
}
RawKvEvent::BlockRemoved { block_hashes, .. } => {
......@@ -400,11 +405,13 @@ fn convert_event(
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes,
}),
dp_rank,
}
}
RawKvEvent::AllBlocksCleared => KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
}
......@@ -1014,7 +1021,7 @@ mod test_event_processing {
medium: None,
};
let out = convert_event(raw_evt, 42, kv_block_size, &Arc::new(AtomicU32::new(0)));
let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Stored(_)));
}
......@@ -1025,7 +1032,7 @@ mod test_event_processing {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None,
};
let out = convert_event(raw_evt, 7, kv_block_size, &Arc::new(AtomicU32::new(0)));
let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Removed(_)));
}
......@@ -1034,7 +1041,7 @@ mod test_event_processing {
fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared;
let out = convert_event(raw_evt, 1, kv_block_size, &Arc::new(AtomicU32::new(0)));
let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared));
}
}
......@@ -1115,6 +1122,7 @@ mod tests_startup_helpers {
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}),
dp_rank: 0,
};
let token = CancellationToken::new();
......
......@@ -12,7 +12,6 @@ mod tests {
use super::*;
use crate::kv_router::indexer::KvIndexer;
use crate::kv_router::indexer::KvIndexerMetrics;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::*;
use std::time::Duration;
use tempfile::tempdir;
......@@ -50,6 +49,7 @@ mod tests {
KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
)
}
......@@ -65,6 +65,7 @@ mod tests {
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
)
}
......
......@@ -18,21 +18,24 @@ use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult;
use super::protocols::{DpRank, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::sequence::ActiveSequencesMultiWorker;
use crate::tokens::SequenceHash;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
pub worker_id: i64,
pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
pub isl_blocks: usize,
pub overlap_blocks: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
pub worker_id: i64,
pub worker_id: WorkerId,
pub dp_rank: DpRank,
pub potential_prefill_tokens: usize,
pub potential_decode_blocks: usize,
}
......@@ -51,7 +54,7 @@ pub enum KvSchedulerError {
#[derive(Debug)]
pub struct SchedulingResponse {
pub best_worker_id: i64,
pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32,
}
......@@ -60,8 +63,8 @@ pub struct SchedulingRequest {
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>,
pub prefill_tokens: HashMap<i64, usize>,
pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
// Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests)
......@@ -94,17 +97,18 @@ impl KvScheduler {
component: Component,
block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>,
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = {
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new();
for instance in &instances {
let worker_id = instance.instance_id;
......@@ -117,14 +121,10 @@ impl KvScheduler {
Arc::new(RwLock::new(initial_map))
};
let worker_ids: Vec<i64> = instances
.iter()
.map(|instance| instance.instance_id)
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size as usize,
worker_ids,
workers_with_configs.read().await.clone(), // this includes dp_size info
replica_sync,
router_uuid,
));
......@@ -162,24 +162,23 @@ impl KvScheduler {
let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone();
// Update workers when instances change
let worker_ids: Vec<i64> = new_instances
.iter()
.map(|instance| instance.instance_id)
.collect();
slots_monitor.update_workers(worker_ids);
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
workers_map.clear();
// Build the new workers_with_configs map
let mut new_workers_with_configs = HashMap::new();
for instance in &new_instances {
let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
workers_map.insert(worker_id, config);
new_workers_with_configs.insert(worker_id, config);
}
// Update workers when instances change
slots_monitor.update_workers(new_workers_with_configs.clone());
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
*workers_map = new_workers_with_configs;
tracing::trace!(
"Updated workers_with_configs with {} workers",
workers_map.len()
......@@ -229,7 +228,8 @@ impl KvScheduler {
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => {
let event = KVHitRateEvent {
worker_id: selection.worker_id,
worker_id: selection.worker.worker_id,
dp_rank: selection.worker.dp_rank,
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
};
......@@ -238,7 +238,7 @@ impl KvScheduler {
}
let response = SchedulingResponse {
best_worker_id: selection.worker_id,
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
};
request.respond(response);
......@@ -261,7 +261,7 @@ impl KvScheduler {
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
selection.worker,
)
.await
{
......@@ -302,7 +302,7 @@ impl KvScheduler {
overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> Result<i64, KvSchedulerError> {
) -> Result<WorkerWithDpRank, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
......@@ -324,8 +324,7 @@ impl KvScheduler {
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
let best_worker_id = response.best_worker_id;
Ok(best_worker_id)
Ok(response.best_worker)
}
pub async fn add_request(
......@@ -334,11 +333,11 @@ impl KvScheduler {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
worker_id: i64,
worker: WorkerWithDpRank,
) {
let _ = self
.slots
.add_request(request_id, token_sequence, isl, overlap, worker_id)
.add_request(request_id, token_sequence, isl, overlap, worker)
.await;
}
......@@ -363,21 +362,22 @@ impl KvScheduler {
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
.await;
// Get all unique worker IDs from both hashmaps
let mut worker_ids: HashSet<i64> = HashSet::new();
worker_ids.extend(decode_blocks.keys().copied());
worker_ids.extend(prefill_tokens.keys().copied());
// Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker
let mut loads = Vec::new();
for worker_id in worker_ids {
for worker in workers {
loads.push(PotentialLoad {
worker_id,
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens
.get(&worker_id)
.get(&worker)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0),
potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
});
}
......@@ -386,7 +386,7 @@ impl KvScheduler {
}
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> WorkerWithDpRank {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
......@@ -474,7 +474,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
workers: &HashMap<WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
......@@ -494,17 +494,28 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker
for worker_id in workers.keys() {
let overlap = *overlaps.get(worker_id).unwrap_or(&0);
// Calculate logits for each worker with dp_rank
// Outer loop: iterate over all workers from runtime config
// Inner loop: iterate over all dp_ranks for each worker
for (worker_id, config) in workers.iter() {
// Get data_parallel_size from runtime config
// data_parallel_size defaults to 1 in ModelRuntimeConfig
let data_parallel_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); // Fallback if config is None
// Iterate over all dp_ranks for this worker
for dp_rank in 0..data_parallel_size {
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
// Get overlap for this worker (defaults to 0 if not in overlaps)
let overlap = *overlaps.get(&worker).unwrap_or(&0);
// this is the number of prefill tokens the worker would have if the request were scheduled there
let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl);
let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks
.get(worker_id)
.get(&worker)
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64;
......@@ -519,14 +530,17 @@ impl WorkerSelector for DefaultWorkerSelector {
let logit = overlap_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit);
worker_logits.insert(worker, logit);
tracing::info!(
"Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
);
}
}
// Use softmax sampling to select worker
// Use override if provided, otherwise use default config
......@@ -535,29 +549,32 @@ impl WorkerSelector for DefaultWorkerSelector {
.as_ref()
.and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature);
let best_worker_id = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker_id];
let best_worker = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0);
// this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers
.get(&best_worker_id)
.get(&best_worker.worker_id)
.and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
tracing::info!(
"Selected worker: {}, logit: {:.3}, cached blocks: {}{}",
best_worker_id,
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}",
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_overlap,
total_blocks_info
);
Ok(WorkerSelectionResult {
worker_id: best_worker_id,
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0),
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
})
}
}
......@@ -570,54 +587,61 @@ mod tests {
fn test_softmax_sample_single_key() {
// Test that with a single key, softmax_sample always returns that key
let mut logits = HashMap::new();
let worker_id = 42;
logits.insert(worker_id, 0.5); // The value doesn't matter
let worker = WorkerWithDpRank::from_worker_id(42);
logits.insert(worker, 0.5); // The value doesn't matter
// Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker_id, "Should return the only available worker");
assert_eq!(result, worker, "Should return the only available worker");
}
// Test with different logit values
logits.clear();
logits.insert(worker_id, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.insert(worker, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear();
logits.insert(worker_id, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.insert(worker, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear();
logits.insert(worker_id, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.insert(worker, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker);
}
#[test]
fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit
let mut logits = HashMap::new();
logits.insert(1, 5.0);
logits.insert(2, 3.0); // This has the smallest logit
logits.insert(3, 7.0);
logits.insert(4, 3.5);
let worker1 = WorkerWithDpRank::from_worker_id(1);
let worker2 = WorkerWithDpRank::from_worker_id(2);
let worker3 = WorkerWithDpRank::from_worker_id(3);
let worker4 = WorkerWithDpRank::from_worker_id(4);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // This has the smallest logit
logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit)
for _ in 0..10 {
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result, 2,
result, worker2,
"Should return worker with smallest logit when temperature is 0"
);
}
// Test with negative values
logits.clear();
logits.insert(10, -1.0);
logits.insert(20, -5.0); // This has the smallest logit
logits.insert(30, 0.0);
let worker10 = WorkerWithDpRank::from_worker_id(10);
let worker20 = WorkerWithDpRank::from_worker_id(20);
let worker30 = WorkerWithDpRank::from_worker_id(30);
logits.insert(worker10, -1.0);
logits.insert(worker20, -5.0); // This has the smallest logit
logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(result, 20, "Should handle negative logits correctly");
assert_eq!(result, worker20, "Should handle negative logits correctly");
}
}
This diff is collapsed.
......@@ -23,7 +23,8 @@ use crate::{
kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerId},
indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
},
};
......
......@@ -211,11 +211,26 @@ impl LocalModelBuilder {
.map(RequestTemplate::load)
.transpose()?;
// Override runtime configs with mocker engine args (applies to both paths)
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.kv_cache_block_size = mocker_engine_args.block_size as u32;
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
}
// frontend and echo engine don't need a path.
if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
);
card.kv_cache_block_size = self.kv_cache_block_size;
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone();
......@@ -266,18 +281,6 @@ impl LocalModelBuilder {
card.context_length = context_length;
}
// Override runtime configs with mocker engine args
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment