Unverified commit f978f4d1, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: dp rank routing (#3597)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 29f5b822
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
NUM_WORKERS=8 NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B" MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1 TENSOR_PARALLEL_SIZE=1
DATA_PARALLEL_SIZE=1
USE_MOCKERS=false USE_MOCKERS=false
USE_TRTLLM=false USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill MODE="agg" # Options: agg (default), decode, prefill
...@@ -28,6 +29,10 @@ while [[ $# -gt 0 ]]; do ...@@ -28,6 +29,10 @@ while [[ $# -gt 0 ]]; do
TENSOR_PARALLEL_SIZE="$2" TENSOR_PARALLEL_SIZE="$2"
shift 2 shift 2
;; ;;
--data-parallel-size)
DATA_PARALLEL_SIZE="$2"
shift 2
;;
--mockers) --mockers)
USE_MOCKERS=true USE_MOCKERS=true
shift shift
...@@ -114,13 +119,19 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt ...@@ -114,13 +119,19 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt
exit 1 exit 1
fi fi
if ! [[ "$DATA_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$DATA_PARALLEL_SIZE" -lt 1 ]; then
echo "Error: DATA_PARALLEL_SIZE must be a positive integer"
exit 1
fi
if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then
echo "Error: BASE_GPU_OFFSET must be a non-negative integer" echo "Error: BASE_GPU_OFFSET must be a non-negative integer"
exit 1 exit 1
fi fi
# Calculate total GPUs needed # Calculate total GPUs needed (TP * DP per worker)
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE)) GPUS_PER_WORKER=$((TENSOR_PARALLEL_SIZE * DATA_PARALLEL_SIZE))
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * GPUS_PER_WORKER))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1)) LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:" echo "Configuration:"
if [ "$USE_MOCKERS" = true ]; then if [ "$USE_MOCKERS" = true ]; then
...@@ -135,6 +146,8 @@ echo " Mode: $MODE" ...@@ -135,6 +146,8 @@ echo " Mode: $MODE"
echo " Workers: $NUM_WORKERS" echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH" echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE" echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Data Parallel Size: $DATA_PARALLEL_SIZE"
echo " GPUs per worker: $GPUS_PER_WORKER"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED" echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU" echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU"
echo " Engine args: ${EXTRA_ARGS[*]}" echo " Engine args: ${EXTRA_ARGS[*]}"
...@@ -155,14 +168,16 @@ echo "Starting $NUM_WORKERS $MODE workers..." ...@@ -155,14 +168,16 @@ echo "Starting $NUM_WORKERS $MODE workers..."
for i in $(seq 1 $NUM_WORKERS); do for i in $(seq 1 $NUM_WORKERS); do
{ {
echo "[${MODE^} Worker-$i] Starting..." MODE_CAPITALIZED=$(echo "$MODE" | sed 's/\(.\)/\U\1/')
echo "[$MODE_CAPITALIZED Worker-$i] Starting..."
# Calculate GPU indices for this worker (with base offset) # Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE )) # Each worker needs TP * DP GPUs
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 )) START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * GPUS_PER_WORKER ))
END_GPU=$(( START_GPU + GPUS_PER_WORKER - 1 ))
# Build CUDA_VISIBLE_DEVICES string # Build CUDA_VISIBLE_DEVICES string for all GPUs (TP * DP)
if [ "$TENSOR_PARALLEL_SIZE" -eq 1 ]; then if [ "$GPUS_PER_WORKER" -eq 1 ]; then
GPU_DEVICES="$START_GPU" GPU_DEVICES="$START_GPU"
else else
GPU_DEVICES="" GPU_DEVICES=""
...@@ -177,12 +192,17 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -177,12 +192,17 @@ for i in $(seq 1 $NUM_WORKERS); do
if [ "$USE_MOCKERS" = true ]; then if [ "$USE_MOCKERS" = true ]; then
# Run mocker engine (no GPU assignment needed) # Run mocker engine (no GPU assignment needed)
exec python -m dynamo.mocker \ MOCKER_ARGS=()
--model-path "$MODEL_PATH" \ MOCKER_ARGS+=("--model-path" "$MODEL_PATH")
--endpoint dyn://test.mocker.generate \ MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
"${EXTRA_ARGS[@]}" if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
exec python -m dynamo.mocker "${MOCKER_ARGS[@]}"
elif [ "$USE_TRTLLM" = true ]; then elif [ "$USE_TRTLLM" = true ]; then
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES" echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization # Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
TRTLLM_ARGS=() TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH") TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
...@@ -195,11 +215,14 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -195,11 +215,14 @@ for i in $(seq 1 $NUM_WORKERS); do
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \ exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}" "${TRTLLM_ARGS[@]}"
else else
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES" echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing # Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=() VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH") VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE") VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ "$MODE" = "prefill" ]; then if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker") VLLM_ARGS+=("--is-prefill-worker")
fi fi
......
...@@ -99,6 +99,7 @@ class StandaloneRouterHandler: ...@@ -99,6 +99,7 @@ class StandaloneRouterHandler:
"eos_token_ids": request.get("eos_token_ids", []), "eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []), "annotations": request.get("annotations", []),
"disaggregated_params": request.get("disaggregated_params"), "disaggregated_params": request.get("disaggregated_params"),
"dp_rank": request.get("dp_rank"),
"extra_args": request.get("extra_args", {}), "extra_args": request.get("extra_args", {}),
} }
......
...@@ -33,7 +33,7 @@ class BaseWorkerHandler(ABC): ...@@ -33,7 +33,7 @@ class BaseWorkerHandler(ABC):
self.component = component self.component = component
self.engine_client = engine self.engine_client = engine
self.default_sampling_params = default_sampling_params self.default_sampling_params = default_sampling_params
self.kv_publisher = None self.kv_publishers = None
self.engine_monitor = VllmEngineMonitor(runtime, engine) self.engine_monitor = VllmEngineMonitor(runtime, engine)
@abstractmethod @abstractmethod
...@@ -81,9 +81,16 @@ class BaseWorkerHandler(ABC): ...@@ -81,9 +81,16 @@ class BaseWorkerHandler(ABC):
"""Override in subclasses if cleanup is needed.""" """Override in subclasses if cleanup is needed."""
pass pass
async def generate_tokens(self, prompt, sampling_params, request_id): async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None
):
try: try:
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(
prompt,
sampling_params,
request_id,
data_parallel_rank=data_parallel_rank,
)
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
try: try:
...@@ -211,10 +218,12 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -211,10 +218,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return return
logger.warning(f"Prefill error: {e}, falling back to local prefill") logger.warning(f"Prefill error: {e}, falling back to local prefill")
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id): async with self._abort_monitor(context, request_id):
try: try:
async for tok in self.generate_tokens( async for tok in self.generate_tokens(
prompt, sampling_params, request_id prompt, sampling_params, request_id, data_parallel_rank=dp_rank
): ):
yield tok yield tok
except EngineDeadError as e: except EngineDeadError as e:
...@@ -241,9 +250,13 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -241,9 +250,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params_dict = extra_args.get("sampling_params", {}) sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams) sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True): async with self._abort_monitor(context, request_id, is_prefill=True):
try: try:
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
)
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}") logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.") logger.warning("Initiating Dynamo Runtime shutdown.")
......
...@@ -5,7 +5,6 @@ import asyncio ...@@ -5,7 +5,6 @@ import asyncio
import logging import logging
import os import os
import signal import signal
from typing import Optional
import uvloop import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
...@@ -107,21 +106,26 @@ def setup_kv_event_publisher( ...@@ -107,21 +106,26 @@ def setup_kv_event_publisher(
component, component,
generate_endpoint, generate_endpoint,
vllm_config, vllm_config,
) -> Optional[ZmqKvEventPublisher]: ):
""" """
Set up KV event publisher for prefix caching if enabled. Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Returns: Returns:
ZmqKvEventPublisher if prefix caching is enabled, None otherwise. List of ZmqKvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.
""" """
if not config.engine_args.enable_prefix_caching: if not config.engine_args.enable_prefix_caching:
return None return None
# TODO: We start off with a valid endpoint, then we increment it by dp_rank # Get data_parallel_size to create publishers for all dp_ranks
# May no longer be valid. Lets remove the increment behavior from vLLM and here data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
kv_publishers = []
for dp_rank in range(data_parallel_size):
# Each dp_rank publishes to a different port
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint, config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0, data_parallel_rank=dp_rank,
).replace("*", "127.0.0.1") ).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig( zmq_config = ZmqKvEventPublisherConfig(
...@@ -130,10 +134,13 @@ def setup_kv_event_publisher( ...@@ -130,10 +134,13 @@ def setup_kv_event_publisher(
zmq_endpoint=zmq_endpoint, zmq_endpoint=zmq_endpoint,
) )
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
logger.info(f"Worker reading KV events from {zmq_endpoint}") logger.info(
f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
)
return kv_publisher return kv_publishers if kv_publishers else None
def setup_vllm_engine(config, stat_logger=None): def setup_vllm_engine(config, stat_logger=None):
...@@ -200,12 +207,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -200,12 +207,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params runtime, component, engine_client, default_sampling_params
) )
# Set up KV event publisher for prefix caching if enabled # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publisher = setup_kv_event_publisher( kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config config, component, generate_endpoint, vllm_config
) )
if kv_publisher: if kv_publishers:
handler.kv_publisher = kv_publisher handler.kv_publishers = kv_publishers
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
...@@ -285,12 +292,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -285,12 +292,12 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client, prefill_router_client,
) )
# Set up KV event publisher for prefix caching if enabled # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publisher = setup_kv_event_publisher( kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config config, component, generate_endpoint, vllm_config
) )
if kv_publisher: if kv_publishers:
handler.kv_publisher = kv_publisher handler.kv_publishers = kv_publishers
if config.engine_args.disable_log_stats is False: if config.engine_args.disable_log_stats is False:
from prometheus_client import REGISTRY from prometheus_client import REGISTRY
...@@ -311,6 +318,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -311,6 +318,12 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.tool_call_parser = config.tool_call_parser runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(
vllm_config.parallel_config, "data_parallel_size", 1
)
runtime_config.data_parallel_size = data_parallel_size
await register_llm( await register_llm(
ModelInput.Tokens, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions, ModelType.Chat | ModelType.Completions,
......
...@@ -322,7 +322,11 @@ The `KvPushRouter` provides the following methods: ...@@ -322,7 +322,11 @@ The `KvPushRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management. - **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`. - **`best_worker(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: **[DEPRECATED - use `best_worker()` instead]** Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state - Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking - With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
......
...@@ -207,6 +207,7 @@ fn kv_event_create_stored_from_parts( ...@@ -207,6 +207,7 @@ fn kv_event_create_stored_from_parts(
parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash), parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
}), }),
event_id: kv_params.event_id, event_id: kv_params.event_id,
dp_rank: 0,
} }
} }
...@@ -224,6 +225,7 @@ fn kv_event_create_removed_from_parts( ...@@ -224,6 +225,7 @@ fn kv_event_create_removed_from_parts(
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: 0,
} }
} }
......
...@@ -101,7 +101,7 @@ impl WorkerMetricsPublisher { ...@@ -101,7 +101,7 @@ impl WorkerMetricsPublisher {
#[derive(Clone)] #[derive(Clone)]
pub struct ZmqKvEventPublisherConfig { pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)] #[pyo3(get, set)]
pub worker_id: i64, pub worker_id: WorkerId,
#[pyo3(get, set)] #[pyo3(get, set)]
pub kv_block_size: usize, pub kv_block_size: usize,
#[pyo3(get, set)] #[pyo3(get, set)]
...@@ -120,7 +120,7 @@ impl ZmqKvEventPublisherConfig { ...@@ -120,7 +120,7 @@ impl ZmqKvEventPublisherConfig {
zmq_topic = "".to_string() zmq_topic = "".to_string()
))] ))]
pub fn new( pub fn new(
worker_id: i64, worker_id: WorkerId,
kv_block_size: usize, kv_block_size: usize,
zmq_endpoint: String, zmq_endpoint: String,
zmq_topic: String, zmq_topic: String,
...@@ -234,13 +234,20 @@ impl Drop for ZmqKvEventListener { ...@@ -234,13 +234,20 @@ impl Drop for ZmqKvEventListener {
pub(crate) struct KvEventPublisher { pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>, inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
kv_block_size: usize, kv_block_size: usize,
dp_rank: DpRank,
warning_count: Arc<AtomicU32>, warning_count: Arc<AtomicU32>,
} }
#[pymethods] #[pymethods]
impl KvEventPublisher { impl KvEventPublisher {
#[new] #[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> { #[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
} }
...@@ -256,6 +263,7 @@ impl KvEventPublisher { ...@@ -256,6 +263,7 @@ impl KvEventPublisher {
Ok(Self { Ok(Self {
inner: inner.into(), inner: inner.into(),
kv_block_size, kv_block_size,
dp_rank,
warning_count: Arc::new(AtomicU32::new(0)), warning_count: Arc::new(AtomicU32::new(0)),
}) })
} }
...@@ -286,6 +294,7 @@ impl KvEventPublisher { ...@@ -286,6 +294,7 @@ impl KvEventPublisher {
&self.warning_count, &self.warning_count,
), ),
}), }),
dp_rank: self.dp_rank,
}; };
self.inner.publish(event).map_err(to_pyerr) self.inner.publish(event).map_err(to_pyerr)
...@@ -299,6 +308,7 @@ impl KvEventPublisher { ...@@ -299,6 +308,7 @@ impl KvEventPublisher {
let event = KvCacheEvent { let event = KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: self.dp_rank,
}; };
self.inner.publish(event).map_err(to_pyerr) self.inner.publish(event).map_err(to_pyerr)
...@@ -314,8 +324,13 @@ pub(crate) struct OverlapScores { ...@@ -314,8 +324,13 @@ pub(crate) struct OverlapScores {
#[pymethods] #[pymethods]
impl OverlapScores { impl OverlapScores {
#[getter] #[getter]
fn scores(&self) -> HashMap<llm_rs::kv_router::indexer::WorkerId, u32> { fn scores(&self) -> HashMap<(i64, u32), u32> {
self.inner.scores.clone() // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
self.inner
.scores
.iter()
.map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
.collect()
} }
#[getter] #[getter]
...@@ -361,7 +376,7 @@ impl RadixTree { ...@@ -361,7 +376,7 @@ impl RadixTree {
fn apply_event( fn apply_event(
&mut self, &mut self,
_py: Python, _py: Python,
worker_id: i64, worker_id: WorkerId,
kv_cache_event_bytes: &[u8], kv_cache_event_bytes: &[u8],
) -> PyResult<()> { ) -> PyResult<()> {
let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent = let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent =
...@@ -377,12 +392,12 @@ impl RadixTree { ...@@ -377,12 +392,12 @@ impl RadixTree {
Ok(()) Ok(())
} }
fn remove_worker(&mut self, _py: Python, worker_id: i64) -> PyResult<()> { fn remove_worker(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.remove_worker(worker_id); self.inner.remove_worker(worker_id);
Ok(()) Ok(())
} }
fn clear_all_blocks(&mut self, _py: Python, worker_id: i64) -> PyResult<()> { fn clear_all_blocks(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.clear_all_blocks(worker_id); self.inner.clear_all_blocks(worker_id);
Ok(()) Ok(())
} }
...@@ -517,16 +532,19 @@ impl ApproxKvIndexer { ...@@ -517,16 +532,19 @@ impl ApproxKvIndexer {
}) })
} }
#[pyo3(signature = (tokens, worker_id, dp_rank=0))]
fn process_routing_decision_for_request<'p>( fn process_routing_decision_for_request<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
tokens: Vec<u32>, tokens: Vec<u32>,
worker_id: i64, worker_id: WorkerId,
dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone(); let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
indexer indexer
.process_routing_decision_for_request(tokens.as_slice(), worker_id) .process_routing_decision_for_request(tokens.as_slice(), worker)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
...@@ -538,7 +556,7 @@ impl ApproxKvIndexer { ...@@ -538,7 +556,7 @@ impl ApproxKvIndexer {
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct EndpointKvMetrics { pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)] #[pyo3(get, set)]
pub worker_id: i64, pub worker_id: WorkerId,
#[pyo3(get, set)] #[pyo3(get, set)]
pub request_active_slots: u64, pub request_active_slots: u64,
#[pyo3(get, set)] #[pyo3(get, set)]
...@@ -784,7 +802,7 @@ impl WorkerStats { ...@@ -784,7 +802,7 @@ impl WorkerStats {
request_active_slots: u64, request_active_slots: u64,
request_total_slots: u64, request_total_slots: u64,
num_requests_waiting: u64, num_requests_waiting: u64,
data_parallel_rank: Option<u32>, data_parallel_rank: Option<DpRank>,
) -> Self { ) -> Self {
Self(RsWorkerStats { Self(RsWorkerStats {
data_parallel_rank, data_parallel_rank,
...@@ -961,7 +979,7 @@ impl KvPushRouter { ...@@ -961,7 +979,7 @@ impl KvPushRouter {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, extra_args=None))] #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None))]
fn generate<'p>( fn generate<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
...@@ -971,7 +989,8 @@ impl KvPushRouter { ...@@ -971,7 +989,8 @@ impl KvPushRouter {
sampling_options: Option<PyObject>, sampling_options: Option<PyObject>,
output_options: Option<PyObject>, output_options: Option<PyObject>,
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
worker_id: Option<i64>, worker_id: Option<WorkerId>,
dp_rank: Option<DpRank>,
extra_args: Option<PyObject>, extra_args: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults // Depythonize the options with defaults
...@@ -1027,6 +1046,7 @@ impl KvPushRouter { ...@@ -1027,6 +1046,7 @@ impl KvPushRouter {
.sampling_options(sampling_options) .sampling_options(sampling_options)
.output_options(output_options) .output_options(output_options)
.router_config_override(router_config_override) .router_config_override(router_config_override)
.dp_rank(dp_rank)
.extra_args(extra_args); .extra_args(extra_args);
// Set backend_instance_id if worker_id is provided // Set backend_instance_id if worker_id is provided
...@@ -1053,6 +1073,43 @@ impl KvPushRouter { ...@@ -1053,6 +1073,43 @@ impl KvPushRouter {
Self::process_request_to_stream(py, self.inner.clone(), request) Self::process_request_to_stream(py, self.inner.clone(), request)
} }
/// Ask the router which worker (and data-parallel rank) it would pick for `token_ids`.
///
/// Resolves on the Python side to a `(worker_id, dp_rank, overlap_blocks)` tuple.
///
/// * `router_config_override` - optional Python object that is depythonized into a
///   `RouterConfigOverride` and applied to this lookup only.
/// * `request_id` - when present, the lookup is performed with state updates enabled
///   so the router tracks the request; when absent, state updates are disabled.
///
/// Returns an error if the override cannot be depythonized or if the router's
/// `find_best_match` call fails.
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker<'p>(
    &self,
    py: Python<'p>,
    token_ids: Vec<u32>,
    router_config_override: Option<PyObject>,
    request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
    // Convert the optional Python override up front so a malformed override is
    // reported before any async work is scheduled.
    let config_override = router_config_override
        .map(|raw| {
            Python::with_gil(|gil| -> PyResult<llm_rs::kv_router::RouterConfigOverride> {
                depythonize(raw.bind(gil)).map_err(to_pyerr)
            })
        })
        .transpose()?;

    // Supplying a request id is what switches the lookup from query-only to tracking.
    let track_request = request_id.is_some();
    let router = self.inner.chooser.clone();

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let matched = router
            .find_best_match(
                request_id.as_deref(),
                &token_ids,
                config_override.as_ref(),
                track_request,
            )
            .await
            .map_err(to_pyerr)?;

        let (selected, overlap_blocks) = matched;
        Ok((selected.worker_id, selected.dp_rank, overlap_blocks))
    })
}
/// Deprecated: Use `best_worker()` instead which returns (worker_id, dp_rank, overlap_blocks)
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))] #[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker_id<'p>( fn best_worker_id<'p>(
&self, &self,
...@@ -1061,6 +1118,16 @@ impl KvPushRouter { ...@@ -1061,6 +1118,16 @@ impl KvPushRouter {
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
request_id: Option<String>, request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Issue deprecation warning
let warnings = py.import("warnings")?;
warnings.call_method1(
"warn",
(
"best_worker_id() is deprecated. Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks)",
py.get_type::<pyo3::exceptions::PyDeprecationWarning>(),
),
)?;
let router_config_override = if let Some(obj) = router_config_override { let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| { Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride = let override_config: llm_rs::kv_router::RouterConfigOverride =
...@@ -1075,7 +1142,7 @@ impl KvPushRouter { ...@@ -1075,7 +1142,7 @@ impl KvPushRouter {
let update_states = request_id.is_some(); let update_states = request_id.is_some();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = chooser let (best_worker, overlap_blocks) = chooser
.find_best_match( .find_best_match(
request_id.as_deref(), request_id.as_deref(),
&token_ids, &token_ids,
...@@ -1085,8 +1152,8 @@ impl KvPushRouter { ...@@ -1085,8 +1152,8 @@ impl KvPushRouter {
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
// Return a tuple of (worker_id, overlap_blocks) // Return only worker_id and overlap_blocks for backward compatibility
Ok((worker_id, overlap_blocks)) Ok((best_worker.worker_id, overlap_blocks))
}) })
} }
...@@ -1130,6 +1197,7 @@ impl KvPushRouter { ...@@ -1130,6 +1197,7 @@ impl KvPushRouter {
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
// Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
// Use pythonize to convert Vec<PotentialLoad> to Python list of dicts // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
Python::with_gil(|py| { Python::with_gil(|py| {
pythonize(py, &loads) pythonize(py, &loads)
......
...@@ -44,6 +44,11 @@ impl ModelRuntimeConfig { ...@@ -44,6 +44,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser = reasoning_parser; self.inner.reasoning_parser = reasoning_parser;
} }
/// Property setter for `data_parallel_size`: stores the given value on the wrapped
/// runtime config (`self.inner`).
///
/// NOTE(review): presumably this is the number of data-parallel ranks the worker
/// hosts, as reported by the engine's parallel config; confirm against the consumer.
#[setter]
fn set_data_parallel_size(&mut self, data_parallel_size: u32) {
self.inner.data_parallel_size = data_parallel_size;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner self.inner
......
...@@ -778,16 +778,21 @@ class KvEventPublisher: ...@@ -778,16 +778,21 @@ class KvEventPublisher:
... ...
def __init__( def __init__(
self, component: Component, worker_id: int, kv_block_size: int self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0
) -> None: ) -> None:
""" """
Create a `KvEventPublisher` object Create a `KvEventPublisher` object
Args:
component: The component to publish events for
worker_id: The worker ID
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
""" """
def publish_stored( def publish_stored(
self, self,
event_id, event_id: int,
int,
token_ids: List[int], token_ids: List[int],
num_block_tokens: List[int], num_block_tokens: List[int],
block_hashes: List[int], block_hashes: List[int],
...@@ -796,12 +801,24 @@ class KvEventPublisher: ...@@ -796,12 +801,24 @@ class KvEventPublisher:
) -> None: ) -> None:
""" """
Publish a KV stored event. Publish a KV stored event.
Args:
event_id: The event ID
token_ids: List of token IDs
num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
""" """
... ...
def publish_removed(self, event_id, int, block_hashes: List[int]) -> None: def publish_removed(self, event_id: int, block_hashes: List[int]) -> None:
""" """
Publish a KV removed event. Publish a KV removed event.
Args:
event_id: The event ID
block_hashes: List of block hashes to remove (signed 64-bit integers)
""" """
... ...
...@@ -1199,6 +1216,7 @@ class KvPushRouter: ...@@ -1199,6 +1216,7 @@ class KvPushRouter:
output_options: Optional[JsonLike] = None, output_options: Optional[JsonLike] = None,
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None, worker_id: Optional[int] = None,
dp_rank: Optional[int] = None,
) -> AsyncIterator[JsonLike]: ) -> AsyncIterator[JsonLike]:
""" """
Generate text using the KV-aware router. Generate text using the KV-aware router.
...@@ -1213,6 +1231,10 @@ class KvPushRouter: ...@@ -1213,6 +1231,10 @@ class KvPushRouter:
worker_id: Optional worker ID to route to directly. If set, the request worker_id: Optional worker ID to route to directly. If set, the request
will be sent to this specific worker and router states will be will be sent to this specific worker and router states will be
updated accordingly. updated accordingly.
dp_rank: Optional data parallel rank to route to. If set along with worker_id,
the request will be routed to the specific (worker_id, dp_rank) pair.
If only dp_rank is set, the router will select the best worker but
force routing to the specified dp_rank.
Returns: Returns:
An async iterator yielding generation responses An async iterator yielding generation responses
...@@ -1220,10 +1242,36 @@ class KvPushRouter: ...@@ -1220,10 +1242,36 @@ class KvPushRouter:
Note: Note:
- If worker_id is set, the request bypasses KV matching and routes directly - If worker_id is set, the request bypasses KV matching and routes directly
to the specified worker while still updating router states. to the specified worker while still updating router states.
- dp_rank allows targeting a specific data parallel replica when workers have
multiple replicas (data_parallel_size > 1).
- This is different from query_instance_id which doesn't route the request. - This is different from query_instance_id which doesn't route the request.
""" """
... ...
async def best_worker(
    self,
    token_ids: List[int],
    router_config_override: Optional[JsonLike] = None,
    request_id: Optional[str] = None,
) -> Tuple[int, int, int]:
    """
    Find the best matching worker for the given tokens.

    Args:
        token_ids: List of token IDs to find matches for
        router_config_override: Optional router configuration override
        request_id: Optional request ID. If provided, router states will be updated
            to track this request (active blocks, lifecycle events). If not
            provided, this is a query-only operation that doesn't affect state.

    Returns:
        A tuple of (worker_id, dp_rank, overlap_blocks) where:
            - worker_id: The ID of the best matching worker
            - dp_rank: The data parallel rank of the selected worker
            - overlap_blocks: The number of overlapping blocks found

    Note:
        The deprecated best_worker_id() returns only (worker_id, overlap_blocks);
        this method additionally returns dp_rank, in the middle position of the tuple.
    """
    ...
async def best_worker_id( async def best_worker_id(
self, self,
token_ids: List[int], token_ids: List[int],
...@@ -1231,6 +1279,8 @@ class KvPushRouter: ...@@ -1231,6 +1279,8 @@ class KvPushRouter:
request_id: Optional[str] = None, request_id: Optional[str] = None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
""" """
[DEPRECATED] Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks).
Find the best matching worker for the given tokens. Find the best matching worker for the given tokens.
Args: Args:
...@@ -1244,6 +1294,9 @@ class KvPushRouter: ...@@ -1244,6 +1294,9 @@ class KvPushRouter:
A tuple of (worker_id, overlap_blocks) where: A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker - worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found - overlap_blocks: The number of overlapping blocks found
.. deprecated::
Use :meth:`best_worker` instead which also returns dp_rank.
""" """
... ...
...@@ -1260,8 +1313,13 @@ class KvPushRouter: ...@@ -1260,8 +1313,13 @@ class KvPushRouter:
Returns: Returns:
A list of dictionaries, each containing: A list of dictionaries, each containing:
- worker_id: The worker ID - worker_id: The worker ID
- dp_rank: The data parallel rank
- potential_prefill_tokens: Number of tokens that would need prefill - potential_prefill_tokens: Number of tokens that would need prefill
- potential_decode_blocks: Number of blocks currently in decode phase - potential_decode_blocks: Number of blocks currently in decode phase
Note:
Each (worker_id, dp_rank) pair is returned as a separate entry.
If you need aggregated loads per worker_id, sum the values manually.
""" """
... ...
...@@ -1287,7 +1345,7 @@ class KvPushRouter: ...@@ -1287,7 +1345,7 @@ class KvPushRouter:
Note: Note:
This is typically called automatically by the router when using the This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using `generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing. `best_worker()` with `request_id` for custom routing.
""" """
... ...
...@@ -1304,7 +1362,7 @@ class KvPushRouter: ...@@ -1304,7 +1362,7 @@ class KvPushRouter:
Note: Note:
This is typically called automatically by the router when using the This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using `generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing. `best_worker()` with `request_id` for custom routing.
""" """
... ...
......
...@@ -85,17 +85,21 @@ async def test_radix_tree_binding(distributed_runtime): ...@@ -85,17 +85,21 @@ async def test_radix_tree_binding(distributed_runtime):
overlap_scores = radix_tree.find_matches([0]) overlap_scores = radix_tree.find_matches([0])
# Verify the results # Verify the results
# Note: scores is now Dict[(worker_id, dp_rank), score]
assert overlap_scores.scores is not None assert overlap_scores.scores is not None
assert ( assert (
len(overlap_scores.scores) == 1 len(overlap_scores.scores) == 1
), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}" ), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}"
assert worker_id in overlap_scores.scores, f"Worker {worker_id} not found in scores" worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert ( assert (
overlap_scores.scores[worker_id] == 1 worker_key in overlap_scores.scores
), f"Expected score 1 for worker {worker_id}, got {overlap_scores.scores[worker_id]}" ), f"Worker {worker_key} not found in scores"
assert (
overlap_scores.scores[worker_key] == 1
), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}"
print( print(
f"✓ RadixTree test passed: worker {worker_id} has score {overlap_scores.scores[worker_id]}" f"✓ RadixTree test passed: worker {worker_key} has score {overlap_scores.scores[worker_key]}"
) )
...@@ -130,24 +134,25 @@ async def test_event_handler(distributed_runtime): ...@@ -130,24 +134,25 @@ async def test_event_handler(distributed_runtime):
event_publisher.store_event(test_token, lora_id) event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously # wait for the event to be processed as it is sent asynchronously
# Retry loop for CI environments where processing may take longer # Retry loop for CI environments where processing may take longer
worker_key = (worker_id, 0) # (worker_id, dp_rank)
for retry in range(10): # Try up to 10 times for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
if ( if (
scores.scores scores.scores
and worker_id in scores.scores and worker_key in scores.scores
and scores.scores[worker_id] == 1 and scores.scores[worker_key] == 1
): ):
break break
if retry == 9: # Last iteration if retry == 9: # Last iteration
# Provide detailed error message for debugging # Provide detailed error message for debugging
assert scores.scores, f"No scores found after {(retry+1)*0.5}s" assert scores.scores, f"No scores found after {(retry+1)*0.5}s"
assert ( assert (
worker_id in scores.scores worker_key in scores.scores
), f"Worker {worker_id} not in scores after {(retry+1)*0.5}s" ), f"Worker {worker_key} not in scores after {(retry+1)*0.5}s"
assert ( assert (
scores.scores[worker_id] == 1 scores.scores[worker_key] == 1
), f"Expected score 1, got {scores.scores.get(worker_id)} after {(retry+1)*0.5}s" ), f"Expected score 1, got {scores.scores.get(worker_key)} after {(retry+1)*0.5}s"
# remove event # remove event
event_publisher.remove_event() event_publisher.remove_event()
...@@ -185,8 +190,9 @@ async def test_approx_kv_indexer(distributed_runtime): ...@@ -185,8 +190,9 @@ async def test_approx_kv_indexer(distributed_runtime):
scores = await indexer.find_matches_for_request(tokens) scores = await indexer.find_matches_for_request(tokens)
assert scores.scores assert scores.scores
assert worker_id in scores.scores worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert scores.scores[worker_id] == 2 assert worker_key in scores.scores
assert scores.scores[worker_key] == 2
class EventPublisher: class EventPublisher:
...@@ -281,7 +287,7 @@ async def metrics_publisher_task(kv_listener, expected_metrics): ...@@ -281,7 +287,7 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics["request_active_slots"], expected_metrics["request_active_slots"],
expected_metrics["request_total_slots"], expected_metrics["request_total_slots"],
expected_metrics["num_requests_waiting"], expected_metrics["num_requests_waiting"],
None, 0, # data_parallel_rank (0 = DP not enabled)
) )
kv_stats = KvStats( kv_stats = KvStats(
......
...@@ -38,7 +38,9 @@ use crate::{ ...@@ -38,7 +38,9 @@ use crate::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent, KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block, compute_block_hash_for_seq, compute_seq_hash_for_block,
}, },
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
subscriber::start_kv_router_background, subscriber::start_kv_router_background,
...@@ -74,7 +76,7 @@ pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock"; ...@@ -74,7 +76,7 @@ pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
pub trait WorkerSelector { pub trait WorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>, workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>; ) -> Result<WorkerSelectionResult, KvSchedulerError>;
...@@ -316,7 +318,7 @@ impl KvRouter { ...@@ -316,7 +318,7 @@ impl KvRouter {
} }
/// Given these tokens, find the worker with the best match in its KV cache. /// Given these tokens, find the worker with the best match in its KV cache.
/// Returned overlap amount is in number of blocks. /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking /// Now also takes optional context_id for request tracking
pub async fn find_best_match( pub async fn find_best_match(
&self, &self,
...@@ -324,7 +326,7 @@ impl KvRouter { ...@@ -324,7 +326,7 @@ impl KvRouter {
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
) -> anyhow::Result<(i64, u32)> { ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
// Validate that context_id is provided when update_states is true // Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() { if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true"); panic!("context_id must be provided if update_states is true");
...@@ -350,7 +352,7 @@ impl KvRouter { ...@@ -350,7 +352,7 @@ impl KvRouter {
(false, false) => (None, None), (false, false) => (None, None),
}; };
let best_worker_id = self let best_worker = self
.scheduler .scheduler
.schedule( .schedule(
context_id.map(|s| s.to_string()), context_id.map(|s| s.to_string()),
...@@ -364,17 +366,17 @@ impl KvRouter { ...@@ -364,17 +366,17 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer indexer
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap()) .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await .await
.unwrap(); .unwrap();
}; };
let overlap_amount = overlap_scores let overlap_amount = overlap_scores
.scores .scores
.get(&best_worker_id) .get(&best_worker)
.copied() .copied()
.unwrap_or(0); .unwrap_or(0);
Ok((best_worker_id, overlap_amount)) Ok((best_worker, overlap_amount))
} }
pub async fn add_request( pub async fn add_request(
...@@ -382,7 +384,7 @@ impl KvRouter { ...@@ -382,7 +384,7 @@ impl KvRouter {
request_id: String, request_id: String,
tokens: &[u32], tokens: &[u32],
overlap_blocks: u32, overlap_blocks: u32,
worker_id: i64, worker: WorkerWithDpRank,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
...@@ -397,7 +399,7 @@ impl KvRouter { ...@@ -397,7 +399,7 @@ impl KvRouter {
maybe_seq_hashes, maybe_seq_hashes,
isl_tokens, isl_tokens,
overlap_blocks, overlap_blocks,
worker_id, worker,
) )
.await; .await;
} }
...@@ -450,12 +452,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -450,12 +452,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
// Handle different request types // Handle different request types
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New { tokens } => {
let (worker_id, overlap_blocks) = self let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true) .find_best_match(Some(&context_id), &tokens, None, true)
.await?; .await?;
RouterResponse::New { RouterResponse::New {
worker_id, worker_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_blocks, overlap_blocks,
} }
} }
...@@ -523,24 +526,45 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -523,24 +526,45 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Check if this is a query_instance_id request first // Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id"); let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id { let (instance_id, dp_rank, overlap_amount) = if let Some(id) =
// If instance_id is set, use it and manually add the request to track it request.backend_instance_id
if !query_instance_id { {
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
}
// Compute actual overlap blocks by querying the indexer
let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
self.chooser self.chooser
.add_request(context_id.clone(), &request.token_ids, 0, id) .add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await; .await;
} (id, dp_rank, overlap_blocks)
(id, 0)
} else { } else {
// Otherwise, find the best match // Otherwise, find the best match
self.chooser let (best_worker, overlap_amount) = self
.chooser
.find_best_match( .find_best_match(
Some(&context_id), Some(&context_id),
&request.token_ids, &request.token_ids,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id !query_instance_id, // Don't update states if query_instance_id
) )
.await? .await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
}; };
// if request has the annotation "query_instance_id", // if request has the annotation "query_instance_id",
...@@ -564,6 +588,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -564,6 +588,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
......
...@@ -27,11 +27,11 @@ use crate::tokens::{SequenceHash, TokenBlockSequence}; ...@@ -27,11 +27,11 @@ use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
WorkerId, compute_block_hash_for_seq, compute_block_hash_for_seq,
}; };
use crate::kv_router::protocols::{ use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, KvCacheStoredBlockData, LocalBlockHash, WorkerId, WorkerWithDpRank,
}; };
#[derive(Debug)] #[derive(Debug)]
...@@ -44,8 +44,8 @@ struct MatchRequest { ...@@ -44,8 +44,8 @@ struct MatchRequest {
#[derive(Debug)] #[derive(Debug)]
struct RouterResult { struct RouterResult {
/// The id of the selected worker. /// The worker (with dp_rank) that was selected.
worker_id: WorkerId, worker: WorkerWithDpRank,
/// The local hashes of the tokens sent to the worker. /// The local hashes of the tokens sent to the worker.
local_hashes: Vec<LocalBlockHash>, local_hashes: Vec<LocalBlockHash>,
...@@ -58,8 +58,8 @@ struct RouterResult { ...@@ -58,8 +58,8 @@ struct RouterResult {
struct TimerEntry { struct TimerEntry {
/// The key of the timer. /// The key of the timer.
key: ExternalSequenceBlockHash, key: ExternalSequenceBlockHash,
/// The worker id that stored this block. /// The worker (with dp_rank) that stored this block.
worker: WorkerId, worker: WorkerWithDpRank,
} }
/// A data structure to manage a collection of timers, addressable by a key. /// A data structure to manage a collection of timers, addressable by a key.
...@@ -237,10 +237,11 @@ impl ApproxKvIndexer { ...@@ -237,10 +237,11 @@ impl ApproxKvIndexer {
event_id += 1; event_id += 1;
let event = RouterEvent::new( let event = RouterEvent::new(
result.worker_id, result.worker.worker_id,
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: stored_event, data: stored_event,
dp_rank: result.worker.dp_rank,
} }
); );
...@@ -248,7 +249,7 @@ impl ApproxKvIndexer { ...@@ -248,7 +249,7 @@ impl ApproxKvIndexer {
timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry { timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry {
key: ExternalSequenceBlockHash(*h), key: ExternalSequenceBlockHash(*h),
worker: result.worker_id, worker: result.worker,
}).collect()); }).collect());
} }
...@@ -269,12 +270,13 @@ impl ApproxKvIndexer { ...@@ -269,12 +270,13 @@ impl ApproxKvIndexer {
event_id += 1; event_id += 1;
let event = RouterEvent::new( let event = RouterEvent::new(
e.worker, e.worker.worker_id,
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key], block_hashes: vec![e.key],
}), }),
dp_rank: e.worker.dp_rank,
} }
); );
...@@ -307,13 +309,13 @@ impl ApproxKvIndexer { ...@@ -307,13 +309,13 @@ impl ApproxKvIndexer {
/// Core function to process a routing decision with pre-computed hashes /// Core function to process a routing decision with pre-computed hashes
pub async fn process_routing_decision( pub async fn process_routing_decision(
&self, &self,
worker_id: WorkerId, worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>, local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>, sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
self.route_tx self.route_tx
.send(RouterResult { .send(RouterResult {
worker_id, worker,
local_hashes, local_hashes,
sequence_hashes, sequence_hashes,
}) })
...@@ -327,7 +329,7 @@ impl ApproxKvIndexer { ...@@ -327,7 +329,7 @@ impl ApproxKvIndexer {
pub async fn process_routing_decision_for_request( pub async fn process_routing_decision_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
worker_id: WorkerId, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size); let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
...@@ -338,7 +340,7 @@ impl ApproxKvIndexer { ...@@ -338,7 +340,7 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash()) .map(|b| b.sequence_hash())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.process_routing_decision(worker_id, local_hashes, sequence_hashes) self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await .await
} }
} }
...@@ -526,14 +528,20 @@ mod tests { ...@@ -526,14 +528,20 @@ mod tests {
// 2. Inform indexer about routing decision // 2. Inform indexer about routing decision
indexer indexer
.process_routing_decision_for_request(&tokens, worker_id) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await .await
.unwrap(); .unwrap();
// Poll until we observe the match being registered // Poll until we observe the match being registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_id).copied() == Some(1) s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_id))
.copied()
== Some(1)
}) })
.await; .await;
...@@ -554,14 +562,18 @@ mod tests { ...@@ -554,14 +562,18 @@ mod tests {
let worker_id: WorkerId = 7; let worker_id: WorkerId = 7;
indexer indexer
.process_routing_decision_for_request(&tokens, worker_id) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await .await
.unwrap(); .unwrap();
// Wait until the worker is registered // Wait until the worker is registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.contains_key(&worker_id) s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
}) })
.await; .await;
...@@ -571,7 +583,8 @@ mod tests { ...@@ -571,7 +583,8 @@ mod tests {
// Ensure the worker's entries are gone // Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_id) !s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
}) })
.await; .await;
} }
...@@ -590,19 +603,31 @@ mod tests { ...@@ -590,19 +603,31 @@ mod tests {
// Register on both workers // Register on both workers
indexer indexer
.process_routing_decision_for_request(&tokens, worker_0) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await .await
.unwrap(); .unwrap();
indexer indexer
.process_routing_decision_for_request(&tokens, worker_1) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await .await
.unwrap(); .unwrap();
// Ensure both workers are registered // Ensure both workers are registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1) s.scores
&& s.scores.get(&worker_1).copied() == Some(1) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
}) })
.await; .await;
...@@ -612,7 +637,12 @@ mod tests { ...@@ -612,7 +637,12 @@ mod tests {
// Confirm the removed worker is gone, and the other remains. // Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_0) && s.scores.get(&worker_1).copied() == Some(1) !s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
}) })
.await; .await;
} }
...@@ -631,14 +661,20 @@ mod tests { ...@@ -631,14 +661,20 @@ mod tests {
// Register Sequence A on worker A // Register Sequence A on worker A
indexer indexer
.process_routing_decision_for_request(&seq_a, worker_a) .process_routing_decision_for_request(
&seq_a,
WorkerWithDpRank::from_worker_id(worker_a),
)
.await .await
.unwrap(); .unwrap();
// Ensure the indexer has registered the block // Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&seq_a).await.unwrap(); let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
s.scores.get(&worker_a).copied() == Some(1) s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a))
.copied()
== Some(1)
}) })
.await; .await;
...@@ -649,7 +685,12 @@ mod tests { ...@@ -649,7 +685,12 @@ mod tests {
let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap(); let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();
// Expect worker A to have an overlap score of 1 (shared first block) // Expect worker A to have an overlap score of 1 (shared first block)
assert_eq!(overlap.scores.get(&worker_a), Some(&1)); assert_eq!(
overlap
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a)),
Some(&1)
);
} }
/// When the same block resides on multiple workers, all should appear in the overlap scores. /// When the same block resides on multiple workers, all should appear in the overlap scores.
...@@ -666,25 +707,47 @@ mod tests { ...@@ -666,25 +707,47 @@ mod tests {
// Register the same sequence on two different workers // Register the same sequence on two different workers
indexer indexer
.process_routing_decision_for_request(&tokens, worker_0) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await .await
.unwrap(); .unwrap();
indexer indexer
.process_routing_decision_for_request(&tokens, worker_1) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await .await
.unwrap(); .unwrap();
// Wait until both workers are reflected in overlap scores // Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1) s.scores
&& s.scores.get(&worker_1).copied() == Some(1) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
}) })
.await; .await;
let scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert_eq!(scores.scores.get(&worker_0), Some(&1)); assert_eq!(
assert_eq!(scores.scores.get(&worker_1), Some(&1)); scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0)),
Some(&1)
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1)),
Some(&1)
);
} }
} }
This diff is collapsed.
...@@ -5,6 +5,34 @@ use crate::tokens::{SequenceHash, Token}; ...@@ -5,6 +5,34 @@ use crate::tokens::{SequenceHash, Token};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
/// Identifier of a worker instance.
pub type WorkerId = i64;

/// Identifier of a data parallel rank within a worker.
pub type DpRank = u32;

/// Routing target: a worker together with one of its data parallel ranks.
///
/// A `dp_rank` of `0` means either that data parallelism is not enabled or
/// that the first rank is addressed.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order, so
/// values sort by `worker_id` first and by `dp_rank` second; keep the field
/// order as is.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerWithDpRank {
    /// The worker being addressed.
    pub worker_id: WorkerId,
    /// The data parallel rank on that worker.
    pub dp_rank: DpRank,
}

impl WorkerWithDpRank {
    /// Builds a target from an explicit `(worker_id, dp_rank)` pair.
    pub fn new(worker_id: WorkerId, dp_rank: DpRank) -> Self {
        WorkerWithDpRank { worker_id, dp_rank }
    }

    /// Builds a target for `worker_id` at rank `0`.
    pub fn from_worker_id(worker_id: WorkerId) -> Self {
        Self::new(worker_id, 0)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")] #[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest { pub enum RouterRequest {
...@@ -26,15 +54,24 @@ impl Default for RouterRequest { ...@@ -26,15 +54,24 @@ impl Default for RouterRequest {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")] #[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterResponse { pub enum RouterResponse {
New { worker_id: i64, overlap_blocks: u32 }, New {
PrefillMarked { success: bool }, worker_id: WorkerId,
FreeMarked { success: bool }, #[serde(default)]
dp_rank: DpRank,
overlap_blocks: u32,
},
PrefillMarked {
success: bool,
},
FreeMarked {
success: bool,
},
} }
#[derive(Debug)] #[derive(Debug)]
pub struct WorkerSelectionResult { pub struct WorkerSelectionResult {
/// The worker id of the selected worker /// The full worker information including dp_rank
pub worker_id: i64, pub worker: WorkerWithDpRank,
/// The total number of blocks required to prefill the request /// The total number of blocks required to prefill the request
pub required_blocks: u64, pub required_blocks: u64,
...@@ -54,7 +91,7 @@ pub struct ForwardPassMetrics { ...@@ -54,7 +91,7 @@ pub struct ForwardPassMetrics {
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats { pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models // https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>, pub data_parallel_rank: Option<DpRank>,
pub request_active_slots: u64, pub request_active_slots: u64,
pub request_total_slots: u64, pub request_total_slots: u64,
pub num_requests_waiting: u64, pub num_requests_waiting: u64,
...@@ -136,7 +173,7 @@ impl From<i64> for ExternalSequenceBlockHash { ...@@ -136,7 +173,7 @@ impl From<i64> for ExternalSequenceBlockHash {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillEvent { pub struct PrefillEvent {
pub request_id: String, pub request_id: String,
pub worker_id: i64, pub worker_id: WorkerId,
pub data: PrefillEventData, pub data: PrefillEventData,
pub router_id: Uuid, pub router_id: Uuid,
} }
...@@ -155,7 +192,7 @@ pub enum PrefillEventData { ...@@ -155,7 +192,7 @@ pub enum PrefillEventData {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActiveSequenceEvent { pub struct ActiveSequenceEvent {
pub request_id: String, pub request_id: String,
pub worker_id: i64, pub worker: WorkerWithDpRank,
pub data: ActiveSequenceEventData, pub data: ActiveSequenceEventData,
pub router_id: Uuid, pub router_id: Uuid,
} }
...@@ -199,6 +236,9 @@ pub struct KvCacheEvent { ...@@ -199,6 +236,9 @@ pub struct KvCacheEvent {
pub event_id: u64, pub event_id: u64,
/// The data associated with the event. /// The data associated with the event.
pub data: KvCacheEventData, pub data: KvCacheEventData,
/// The data parallel rank of the worker emitting this event (0 if DP not enabled).
#[serde(default)]
pub dp_rank: DpRank,
} }
/// Represents the data associated with a cache event. /// Represents the data associated with a cache event.
...@@ -313,6 +353,7 @@ mod tests { ...@@ -313,6 +353,7 @@ mod tests {
let event = KvCacheEvent { let event = KvCacheEvent {
event_id: 1, event_id: 1,
data: event_data, data: event_data,
dp_rank: 0,
}; };
let events = KvCacheEvents { let events = KvCacheEvents {
......
...@@ -326,13 +326,16 @@ pub async fn start_zmq_listener( ...@@ -326,13 +326,16 @@ pub async fn start_zmq_listener(
}; };
tracing::trace!( tracing::trace!(
"ZMQ listener on {} received batch with {} events (seq={})", "ZMQ listener on {} received batch with {} events (seq={}, dp_rank={})",
zmq_endpoint, zmq_endpoint,
batch.events.len(), batch.events.len(),
seq seq,
batch.data_parallel_rank
); );
let dp_rank = batch.data_parallel_rank;
for raw_event in batch.events.into_iter() { for raw_event in batch.events.into_iter() {
let event = convert_event(raw_event, seq, kv_block_size, &warning_count); let event = convert_event(raw_event, seq, kv_block_size, dp_rank, &warning_count);
if tx.send(event).is_err() { if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped"); tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped"; exit_reason = "channel receiver dropped";
...@@ -356,6 +359,7 @@ fn convert_event( ...@@ -356,6 +359,7 @@ fn convert_event(
raw: RawKvEvent, raw: RawKvEvent,
event_id: u64, event_id: u64,
kv_block_size: u32, kv_block_size: u32,
dp_rank: u32,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent { ) -> KvCacheEvent {
match raw { match raw {
...@@ -387,6 +391,7 @@ fn convert_event( ...@@ -387,6 +391,7 @@ fn convert_event(
warning_count, warning_count,
), ),
}), }),
dp_rank,
} }
} }
RawKvEvent::BlockRemoved { block_hashes, .. } => { RawKvEvent::BlockRemoved { block_hashes, .. } => {
...@@ -400,11 +405,13 @@ fn convert_event( ...@@ -400,11 +405,13 @@ fn convert_event(
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes, block_hashes: hashes,
}), }),
dp_rank,
} }
} }
RawKvEvent::AllBlocksCleared => KvCacheEvent { RawKvEvent::AllBlocksCleared => KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
dp_rank,
}, },
} }
} }
...@@ -1014,7 +1021,7 @@ mod test_event_processing { ...@@ -1014,7 +1021,7 @@ mod test_event_processing {
medium: None, medium: None,
}; };
let out = convert_event(raw_evt, 42, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Stored(_))); assert!(matches!(out.data, KvCacheEventData::Stored(_)));
} }
...@@ -1025,7 +1032,7 @@ mod test_event_processing { ...@@ -1025,7 +1032,7 @@ mod test_event_processing {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)], block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None, medium: None,
}; };
let out = convert_event(raw_evt, 7, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Removed(_))); assert!(matches!(out.data, KvCacheEventData::Removed(_)));
} }
...@@ -1034,7 +1041,7 @@ mod test_event_processing { ...@@ -1034,7 +1041,7 @@ mod test_event_processing {
// Verifies that converting `RawKvEvent::AllBlocksCleared` yields an event whose
// `data` is `KvCacheEventData::Cleared`.
// NOTE(review): every row below is a side-by-side diff row (old text first,
// new text second). The new call passes an extra `0` argument, which lines up
// with the `dp_rank: u32` parameter added to `convert_event`; the resulting
// `out.dp_rank` is not asserted here.
fn test_convert_event_all_blocks_cleared() { fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4; let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared; let raw_evt = RawKvEvent::AllBlocksCleared;
let out = convert_event(raw_evt, 1, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared)); assert!(matches!(out.data, KvCacheEventData::Cleared));
} }
} }
...@@ -1115,6 +1122,7 @@ mod tests_startup_helpers { ...@@ -1115,6 +1122,7 @@ mod tests_startup_helpers {
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)], block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}), }),
dp_rank: 0,
}; };
let token = CancellationToken::new(); let token = CancellationToken::new();
......
...@@ -12,7 +12,6 @@ mod tests { ...@@ -12,7 +12,6 @@ mod tests {
use super::*; use super::*;
use crate::kv_router::indexer::KvIndexer; use crate::kv_router::indexer::KvIndexer;
use crate::kv_router::indexer::KvIndexerMetrics; use crate::kv_router::indexer::KvIndexerMetrics;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::*; use crate::kv_router::protocols::*;
use std::time::Duration; use std::time::Duration;
use tempfile::tempdir; use tempfile::tempdir;
...@@ -50,6 +49,7 @@ mod tests { ...@@ -50,6 +49,7 @@ mod tests {
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: add_blocks(hashes, parent), data: add_blocks(hashes, parent),
dp_rank: 0,
}, },
) )
} }
...@@ -65,6 +65,7 @@ mod tests { ...@@ -65,6 +65,7 @@ mod tests {
.map(|i| ExternalSequenceBlockHash(*i * 100)) .map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(), .collect(),
}), }),
dp_rank: 0,
}, },
) )
} }
......
...@@ -18,21 +18,24 @@ use super::KvRouterConfig; ...@@ -18,21 +18,24 @@ use super::KvRouterConfig;
use super::RouterConfigOverride; use super::RouterConfigOverride;
use super::WorkerSelector; use super::WorkerSelector;
use super::indexer::OverlapScores; use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult; use super::protocols::{DpRank, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::sequence::ActiveSequencesMultiWorker; use super::sequence::ActiveSequencesMultiWorker;
use crate::tokens::SequenceHash; use crate::tokens::SequenceHash;
/// Serde-serializable summary of one routing decision.
///
/// The scheduler fills it from the selector's result: `worker_id` and
/// `dp_rank` come from the selected worker, `isl_blocks` from the selection's
/// `required_blocks` (cast to `usize`), and `overlap_blocks` from the
/// selection's `overlap_blocks`.
// NOTE(review): every row below is a side-by-side diff row (old text first,
// new text second); rows with a single column exist only on the new side.
// `#[serde(default)]` lets payloads that lack `dp_rank` still deserialize,
// falling back to `DpRank::default()` (presumably 0 -- `DpRank` looks like an
// integer alias; confirm in `protocols`).
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent { pub struct KVHitRateEvent {
pub worker_id: i64, pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
pub isl_blocks: usize, pub isl_blocks: usize,
pub overlap_blocks: u32, pub overlap_blocks: u32,
} }
/// Serde-serializable projected load for one (worker, dp_rank) pair.
///
/// Where it is built in the scheduler, `potential_prefill_tokens` falls back
/// to the request's `isl_tokens` and `potential_decode_blocks` falls back to 0
/// when the pair has no entry in the corresponding map.
// NOTE(review): every row below is a side-by-side diff row (old text first,
// new text second); rows with a single column exist only on the new side.
// Unlike `KVHitRateEvent`, the new `dp_rank` field carries no
// `#[serde(default)]`, so deserializing a payload written without `dp_rank`
// would fail -- confirm whether that asymmetry is intended.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad { pub struct PotentialLoad {
pub worker_id: i64, pub worker_id: WorkerId,
pub dp_rank: DpRank,
pub potential_prefill_tokens: usize, pub potential_prefill_tokens: usize,
pub potential_decode_blocks: usize, pub potential_decode_blocks: usize,
} }
...@@ -51,7 +54,7 @@ pub enum KvSchedulerError { ...@@ -51,7 +54,7 @@ pub enum KvSchedulerError {
/// Reply sent back for a scheduling request: the chosen worker and the
/// selection's `overlap_blocks` for that worker.
// NOTE(review): every row below is a side-by-side diff row (old text first,
// new text second). The change replaces `best_worker_id: i64` with
// `best_worker: WorkerWithDpRank`, so the reply now identifies the data
// parallel rank as well as the worker.
#[derive(Debug)] #[derive(Debug)]
pub struct SchedulingResponse { pub struct SchedulingResponse {
pub best_worker_id: i64, pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32, pub overlap_blocks: u32,
} }
...@@ -60,8 +63,8 @@ pub struct SchedulingRequest { ...@@ -60,8 +63,8 @@ pub struct SchedulingRequest {
pub token_seq: Option<Vec<SequenceHash>>, pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlaps: OverlapScores, pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>, pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: HashMap<i64, usize>, pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
// Router config overrides for this specific request // Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests) // Whether to update scheduler states (false for query_instance_id requests)
...@@ -94,17 +97,18 @@ impl KvScheduler { ...@@ -94,17 +97,18 @@ impl KvScheduler {
component: Component, component: Component,
block_size: u32, block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>, instances_rx: watch::Receiver<Vec<Instance>>,
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>, runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool, replica_sync: bool,
router_uuid: String, router_uuid: String,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone(); let instances: Vec<Instance> = instances_rx.borrow().clone();
let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone(); let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock> // Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = { let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new(); let mut initial_map = HashMap::new();
for instance in &instances { for instance in &instances {
let worker_id = instance.instance_id; let worker_id = instance.instance_id;
...@@ -117,14 +121,10 @@ impl KvScheduler { ...@@ -117,14 +121,10 @@ impl KvScheduler {
Arc::new(RwLock::new(initial_map)) Arc::new(RwLock::new(initial_map))
}; };
let worker_ids: Vec<i64> = instances
.iter()
.map(|instance| instance.instance_id)
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new( let slots = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size as usize, block_size as usize,
worker_ids, workers_with_configs.read().await.clone(), // this includes dp_size info
replica_sync, replica_sync,
router_uuid, router_uuid,
)); ));
...@@ -162,24 +162,23 @@ impl KvScheduler { ...@@ -162,24 +162,23 @@ impl KvScheduler {
let new_instances = instances_monitor_rx.borrow_and_update().clone(); let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone(); let new_configs = configs_monitor_rx.borrow_and_update().clone();
// Update workers when instances change // Build the new workers_with_configs map
let worker_ids: Vec<i64> = new_instances let mut new_workers_with_configs = HashMap::new();
.iter()
.map(|instance| instance.instance_id)
.collect();
slots_monitor.update_workers(worker_ids);
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
workers_map.clear();
for instance in &new_instances { for instance in &new_instances {
let worker_id = instance.instance_id; let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned(); let config = new_configs.get(&worker_id).cloned();
if config.is_some() { if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id); tracing::info!("Runtime config found for worker_id: {}", worker_id);
} }
workers_map.insert(worker_id, config); new_workers_with_configs.insert(worker_id, config);
} }
// Update workers when instances change
slots_monitor.update_workers(new_workers_with_configs.clone());
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
*workers_map = new_workers_with_configs;
tracing::trace!( tracing::trace!(
"Updated workers_with_configs with {} workers", "Updated workers_with_configs with {} workers",
workers_map.len() workers_map.len()
...@@ -229,7 +228,8 @@ impl KvScheduler { ...@@ -229,7 +228,8 @@ impl KvScheduler {
match selector.select_worker(&workers, &request, block_size) { match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => { Ok(selection) => {
let event = KVHitRateEvent { let event = KVHitRateEvent {
worker_id: selection.worker_id, worker_id: selection.worker.worker_id,
dp_rank: selection.worker.dp_rank,
isl_blocks: selection.required_blocks as usize, isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks, overlap_blocks: selection.overlap_blocks,
}; };
...@@ -238,7 +238,7 @@ impl KvScheduler { ...@@ -238,7 +238,7 @@ impl KvScheduler {
} }
let response = SchedulingResponse { let response = SchedulingResponse {
best_worker_id: selection.worker_id, best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks, overlap_blocks: selection.overlap_blocks,
}; };
request.respond(response); request.respond(response);
...@@ -261,7 +261,7 @@ impl KvScheduler { ...@@ -261,7 +261,7 @@ impl KvScheduler {
request.token_seq, request.token_seq,
request.isl_tokens, request.isl_tokens,
selection.overlap_blocks, selection.overlap_blocks,
selection.worker_id, selection.worker,
) )
.await .await
{ {
...@@ -302,7 +302,7 @@ impl KvScheduler { ...@@ -302,7 +302,7 @@ impl KvScheduler {
overlaps: OverlapScores, overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
) -> Result<i64, KvSchedulerError> { ) -> Result<WorkerWithDpRank, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
maybe_request_id, maybe_request_id,
...@@ -324,8 +324,7 @@ impl KvScheduler { ...@@ -324,8 +324,7 @@ impl KvScheduler {
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
let best_worker_id = response.best_worker_id; Ok(response.best_worker)
Ok(best_worker_id)
} }
pub async fn add_request( pub async fn add_request(
...@@ -334,11 +333,11 @@ impl KvScheduler { ...@@ -334,11 +333,11 @@ impl KvScheduler {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
worker_id: i64, worker: WorkerWithDpRank,
) { ) {
let _ = self let _ = self
.slots .slots
.add_request(request_id, token_sequence, isl, overlap, worker_id) .add_request(request_id, token_sequence, isl, overlap, worker)
.await; .await;
} }
...@@ -363,21 +362,22 @@ impl KvScheduler { ...@@ -363,21 +362,22 @@ impl KvScheduler {
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps) .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
.await; .await;
// Get all unique worker IDs from both hashmaps // Get all unique WorkerWithDpRank from both hashmaps
let mut worker_ids: HashSet<i64> = HashSet::new(); let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
worker_ids.extend(decode_blocks.keys().copied()); workers.extend(decode_blocks.keys().copied());
worker_ids.extend(prefill_tokens.keys().copied()); workers.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker // Create PotentialLoad for each worker
let mut loads = Vec::new(); let mut loads = Vec::new();
for worker_id in worker_ids { for worker in workers {
loads.push(PotentialLoad { loads.push(PotentialLoad {
worker_id, worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens potential_prefill_tokens: prefill_tokens
.get(&worker_id) .get(&worker)
.copied() .copied()
.unwrap_or(isl_tokens), .unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0), potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
}); });
} }
...@@ -386,7 +386,7 @@ impl KvScheduler { ...@@ -386,7 +386,7 @@ impl KvScheduler {
} }
// Helper function for softmax sampling // Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 { fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> WorkerWithDpRank {
if logits.is_empty() { if logits.is_empty() {
panic!("Empty logits for softmax sampling"); panic!("Empty logits for softmax sampling");
} }
...@@ -474,7 +474,7 @@ impl DefaultWorkerSelector { ...@@ -474,7 +474,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector { impl WorkerSelector for DefaultWorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>, workers: &HashMap<WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
...@@ -494,17 +494,28 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -494,17 +494,28 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut worker_logits = HashMap::new(); let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY; let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker // Calculate logits for each worker with dp_rank
for worker_id in workers.keys() { // Outer loop: iterate over all workers from runtime config
let overlap = *overlaps.get(worker_id).unwrap_or(&0); // Inner loop: iterate over all dp_ranks for each worker
for (worker_id, config) in workers.iter() {
// Get data_parallel_size from runtime config
// data_parallel_size defaults to 1 in ModelRuntimeConfig
let data_parallel_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); // Fallback if config is None
// Iterate over all dp_ranks for this worker
for dp_rank in 0..data_parallel_size {
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
// Get overlap for this worker (defaults to 0 if not in overlaps)
let overlap = *overlaps.get(&worker).unwrap_or(&0);
// this is the number of prefill tokens the worker would have if the request were scheduled there // this is the number of prefill tokens the worker would have if the request were scheduled there
let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl); let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64); let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// this is the number of decode blocks the worker would have if the request were scheduled there // this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks let decode_block = *decode_blocks
.get(worker_id) .get(&worker)
.unwrap_or(&(potential_prefill_block.floor() as usize)) .unwrap_or(&(potential_prefill_block.floor() as usize))
as f64; as f64;
...@@ -519,14 +530,17 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -519,14 +530,17 @@ impl WorkerSelector for DefaultWorkerSelector {
let logit = overlap_weight * potential_prefill_block + decode_block; let logit = overlap_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit); max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit); worker_logits.insert(worker, logit);
tracing::info!( tracing::info!(
"Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \ "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}" = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
); );
} }
}
// Use softmax sampling to select worker // Use softmax sampling to select worker
// Use override if provided, otherwise use default config // Use override if provided, otherwise use default config
...@@ -535,29 +549,32 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -535,29 +549,32 @@ impl WorkerSelector for DefaultWorkerSelector {
.as_ref() .as_ref()
.and_then(|cfg| cfg.router_temperature) .and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let best_worker_id = softmax_sample(&worker_logits, temperature); let best_worker = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker_id]; let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0); // this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers let total_blocks_info = workers
.get(&best_worker_id) .get(&best_worker.worker_id)
.and_then(|cfg| cfg.as_ref()) .and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks) .and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks)) .map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default(); .unwrap_or_default();
tracing::info!( tracing::info!(
"Selected worker: {}, logit: {:.3}, cached blocks: {}{}", "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}",
best_worker_id, best_worker.worker_id,
best_worker.dp_rank,
best_logit, best_logit,
best_overlap, best_overlap,
total_blocks_info total_blocks_info
); );
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker_id: best_worker_id, worker: best_worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0), overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
}) })
} }
} }
...@@ -570,54 +587,61 @@ mod tests { ...@@ -570,54 +587,61 @@ mod tests {
// Verifies that `softmax_sample` returns the only key of a one-entry map,
// across three temperatures (0.1, 1.0, 10.0) and across very negative, very
// positive and zero logit values.
// NOTE(review): every row below is a side-by-side diff row (old text first,
// new text second). The new side keys the map by a `WorkerWithDpRank` built
// with `from_worker_id(42)` instead of the bare integer 42.
fn test_softmax_sample_single_key() { fn test_softmax_sample_single_key() {
// Test that with a single key, softmax_sample always returns that key // Test that with a single key, softmax_sample always returns that key
let mut logits = HashMap::new(); let mut logits = HashMap::new();
let worker_id = 42; let worker = WorkerWithDpRank::from_worker_id(42);
logits.insert(worker_id, 0.5); // The value doesn't matter logits.insert(worker, 0.5); // The value doesn't matter
// Test with different temperatures // Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] { for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature); let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker_id, "Should return the only available worker"); assert_eq!(result, worker, "Should return the only available worker");
} }
// Test with different logit values // Test with different logit values
logits.clear(); logits.clear();
logits.insert(worker_id, -100.0); // Very negative value logits.insert(worker, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear(); logits.clear();
logits.insert(worker_id, 100.0); // Very positive value logits.insert(worker, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear(); logits.clear();
logits.insert(worker_id, 0.0); // Zero value logits.insert(worker, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker);
} }
// Verifies that `softmax_sample` with temperature 0.0 returns the key holding
// the smallest logit: ten repeated draws over four positive logits, then one
// draw over a map that includes negative logits.
// NOTE(review): every row below is a side-by-side diff row (old text first,
// new text second); rows with a single column exist only on the new side.
// The new side builds `WorkerWithDpRank` keys via `from_worker_id` and asserts
// against those values instead of bare integers.
#[test] #[test]
fn test_softmax_sample_zero_temperature() { fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit // Test that with temperature 0, softmax_sample returns the key with smallest logit
let mut logits = HashMap::new(); let mut logits = HashMap::new();
logits.insert(1, 5.0); let worker1 = WorkerWithDpRank::from_worker_id(1);
logits.insert(2, 3.0); // This has the smallest logit let worker2 = WorkerWithDpRank::from_worker_id(2);
logits.insert(3, 7.0); let worker3 = WorkerWithDpRank::from_worker_id(3);
logits.insert(4, 3.5); let worker4 = WorkerWithDpRank::from_worker_id(4);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // This has the smallest logit
logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit) // With temperature 0, should always return worker 2 (smallest logit)
for _ in 0..10 { for _ in 0..10 {
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!( assert_eq!(
result, 2, result, worker2,
"Should return worker with smallest logit when temperature is 0" "Should return worker with smallest logit when temperature is 0"
); );
} }
// Test with negative values // Test with negative values
logits.clear(); logits.clear();
logits.insert(10, -1.0); let worker10 = WorkerWithDpRank::from_worker_id(10);
logits.insert(20, -5.0); // This has the smallest logit let worker20 = WorkerWithDpRank::from_worker_id(20);
logits.insert(30, 0.0); let worker30 = WorkerWithDpRank::from_worker_id(30);
logits.insert(worker10, -1.0);
logits.insert(worker20, -5.0); // This has the smallest logit
logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!(result, 20, "Should handle negative logits correctly"); assert_eq!(result, worker20, "Should handle negative logits correctly");
} }
} }
This diff is collapsed.
...@@ -23,7 +23,8 @@ use crate::{ ...@@ -23,7 +23,8 @@ use crate::{
kv_router::{ kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK, ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerId}, indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
}, },
}; };
......
...@@ -211,11 +211,26 @@ impl LocalModelBuilder { ...@@ -211,11 +211,26 @@ impl LocalModelBuilder {
.map(RequestTemplate::load) .map(RequestTemplate::load)
.transpose()?; .transpose()?;
// Override runtime configs with mocker engine args (applies to both paths)
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.kv_cache_block_size = mocker_engine_args.block_size as u32;
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
}
// frontend and echo engine don't need a path. // frontend and echo engine don't need a path.
if self.model_path.is_none() { if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only( let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME), self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
); );
card.kv_cache_block_size = self.kv_cache_block_size;
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone(); card.runtime_config = self.runtime_config.clone();
...@@ -266,18 +281,6 @@ impl LocalModelBuilder { ...@@ -266,18 +281,6 @@ impl LocalModelBuilder {
card.context_length = context_length; card.context_length = context_length;
} }
// Override runtime configs with mocker engine args
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone(); card.runtime_config = self.runtime_config.clone();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment