"container/vscode:/vscode.git/clone" did not exist on "b6596c52f41a0613977923579dd13bd1296d90dc"
Unverified Commit f978f4d1 authored by Yan Ru Pei, committed by GitHub
Browse files

feat: dp rank routing (#3597)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 29f5b822
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
NUM_WORKERS=8 NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B" MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1 TENSOR_PARALLEL_SIZE=1
DATA_PARALLEL_SIZE=1
USE_MOCKERS=false USE_MOCKERS=false
USE_TRTLLM=false USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill MODE="agg" # Options: agg (default), decode, prefill
...@@ -28,6 +29,10 @@ while [[ $# -gt 0 ]]; do ...@@ -28,6 +29,10 @@ while [[ $# -gt 0 ]]; do
TENSOR_PARALLEL_SIZE="$2" TENSOR_PARALLEL_SIZE="$2"
shift 2 shift 2
;; ;;
--data-parallel-size)
DATA_PARALLEL_SIZE="$2"
shift 2
;;
--mockers) --mockers)
USE_MOCKERS=true USE_MOCKERS=true
shift shift
...@@ -114,13 +119,19 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt ...@@ -114,13 +119,19 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt
exit 1 exit 1
fi fi
if ! [[ "$DATA_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$DATA_PARALLEL_SIZE" -lt 1 ]; then
echo "Error: DATA_PARALLEL_SIZE must be a positive integer"
exit 1
fi
if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then
echo "Error: BASE_GPU_OFFSET must be a non-negative integer" echo "Error: BASE_GPU_OFFSET must be a non-negative integer"
exit 1 exit 1
fi fi
# Calculate total GPUs needed # Calculate total GPUs needed (TP * DP per worker)
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE)) GPUS_PER_WORKER=$((TENSOR_PARALLEL_SIZE * DATA_PARALLEL_SIZE))
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * GPUS_PER_WORKER))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1)) LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:" echo "Configuration:"
if [ "$USE_MOCKERS" = true ]; then if [ "$USE_MOCKERS" = true ]; then
...@@ -135,6 +146,8 @@ echo " Mode: $MODE" ...@@ -135,6 +146,8 @@ echo " Mode: $MODE"
echo " Workers: $NUM_WORKERS" echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH" echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE" echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Data Parallel Size: $DATA_PARALLEL_SIZE"
echo " GPUs per worker: $GPUS_PER_WORKER"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED" echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU" echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU"
echo " Engine args: ${EXTRA_ARGS[*]}" echo " Engine args: ${EXTRA_ARGS[*]}"
...@@ -155,14 +168,16 @@ echo "Starting $NUM_WORKERS $MODE workers..." ...@@ -155,14 +168,16 @@ echo "Starting $NUM_WORKERS $MODE workers..."
for i in $(seq 1 $NUM_WORKERS); do for i in $(seq 1 $NUM_WORKERS); do
{ {
echo "[${MODE^} Worker-$i] Starting..." MODE_CAPITALIZED=$(echo "$MODE" | sed 's/\(.\)/\U\1/')
echo "[$MODE_CAPITALIZED Worker-$i] Starting..."
# Calculate GPU indices for this worker (with base offset) # Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE )) # Each worker needs TP * DP GPUs
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 )) START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * GPUS_PER_WORKER ))
END_GPU=$(( START_GPU + GPUS_PER_WORKER - 1 ))
# Build CUDA_VISIBLE_DEVICES string # Build CUDA_VISIBLE_DEVICES string for all GPUs (TP * DP)
if [ "$TENSOR_PARALLEL_SIZE" -eq 1 ]; then if [ "$GPUS_PER_WORKER" -eq 1 ]; then
GPU_DEVICES="$START_GPU" GPU_DEVICES="$START_GPU"
else else
GPU_DEVICES="" GPU_DEVICES=""
...@@ -177,12 +192,17 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -177,12 +192,17 @@ for i in $(seq 1 $NUM_WORKERS); do
if [ "$USE_MOCKERS" = true ]; then if [ "$USE_MOCKERS" = true ]; then
# Run mocker engine (no GPU assignment needed) # Run mocker engine (no GPU assignment needed)
exec python -m dynamo.mocker \ MOCKER_ARGS=()
--model-path "$MODEL_PATH" \ MOCKER_ARGS+=("--model-path" "$MODEL_PATH")
--endpoint dyn://test.mocker.generate \ MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
"${EXTRA_ARGS[@]}" if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
exec python -m dynamo.mocker "${MOCKER_ARGS[@]}"
elif [ "$USE_TRTLLM" = true ]; then elif [ "$USE_TRTLLM" = true ]; then
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES" echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization # Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
TRTLLM_ARGS=() TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH") TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
...@@ -195,11 +215,14 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -195,11 +215,14 @@ for i in $(seq 1 $NUM_WORKERS); do
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \ exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}" "${TRTLLM_ARGS[@]}"
else else
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES" echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing # Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=() VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH") VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE") VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ "$MODE" = "prefill" ]; then if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker") VLLM_ARGS+=("--is-prefill-worker")
fi fi
......
...@@ -99,6 +99,7 @@ class StandaloneRouterHandler: ...@@ -99,6 +99,7 @@ class StandaloneRouterHandler:
"eos_token_ids": request.get("eos_token_ids", []), "eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []), "annotations": request.get("annotations", []),
"disaggregated_params": request.get("disaggregated_params"), "disaggregated_params": request.get("disaggregated_params"),
"dp_rank": request.get("dp_rank"),
"extra_args": request.get("extra_args", {}), "extra_args": request.get("extra_args", {}),
} }
......
...@@ -33,7 +33,7 @@ class BaseWorkerHandler(ABC): ...@@ -33,7 +33,7 @@ class BaseWorkerHandler(ABC):
self.component = component self.component = component
self.engine_client = engine self.engine_client = engine
self.default_sampling_params = default_sampling_params self.default_sampling_params = default_sampling_params
self.kv_publisher = None self.kv_publishers = None
self.engine_monitor = VllmEngineMonitor(runtime, engine) self.engine_monitor = VllmEngineMonitor(runtime, engine)
@abstractmethod @abstractmethod
...@@ -81,9 +81,16 @@ class BaseWorkerHandler(ABC): ...@@ -81,9 +81,16 @@ class BaseWorkerHandler(ABC):
"""Override in subclasses if cleanup is needed.""" """Override in subclasses if cleanup is needed."""
pass pass
async def generate_tokens(self, prompt, sampling_params, request_id): async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None
):
try: try:
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(
prompt,
sampling_params,
request_id,
data_parallel_rank=data_parallel_rank,
)
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
try: try:
...@@ -211,10 +218,12 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -211,10 +218,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return return
logger.warning(f"Prefill error: {e}, falling back to local prefill") logger.warning(f"Prefill error: {e}, falling back to local prefill")
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id): async with self._abort_monitor(context, request_id):
try: try:
async for tok in self.generate_tokens( async for tok in self.generate_tokens(
prompt, sampling_params, request_id prompt, sampling_params, request_id, data_parallel_rank=dp_rank
): ):
yield tok yield tok
except EngineDeadError as e: except EngineDeadError as e:
...@@ -241,9 +250,13 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -241,9 +250,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params_dict = extra_args.get("sampling_params", {}) sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams) sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True): async with self._abort_monitor(context, request_id, is_prefill=True):
try: try:
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
)
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}") logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.") logger.warning("Initiating Dynamo Runtime shutdown.")
......
...@@ -5,7 +5,6 @@ import asyncio ...@@ -5,7 +5,6 @@ import asyncio
import logging import logging
import os import os
import signal import signal
from typing import Optional
import uvloop import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
...@@ -107,33 +106,41 @@ def setup_kv_event_publisher( ...@@ -107,33 +106,41 @@ def setup_kv_event_publisher(
component, component,
generate_endpoint, generate_endpoint,
vllm_config, vllm_config,
) -> Optional[ZmqKvEventPublisher]: ):
""" """
Set up KV event publisher for prefix caching if enabled. Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Returns: Returns:
ZmqKvEventPublisher if prefix caching is enabled, None otherwise. List of ZmqKvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.
""" """
if not config.engine_args.enable_prefix_caching: if not config.engine_args.enable_prefix_caching:
return None return None
# TODO: We start off with a valid endpoint, then we increment it by dp_rank # Get data_parallel_size to create publishers for all dp_ranks
# May no longer be valid. Lets remove the increment behavior from vLLM and here data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( kv_publishers = []
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0, for dp_rank in range(data_parallel_size):
).replace("*", "127.0.0.1") # Each dp_rank publishes to a different port
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
zmq_config = ZmqKvEventPublisherConfig( config.engine_args.kv_events_config.endpoint,
worker_id=generate_endpoint.lease_id(), data_parallel_rank=dp_rank,
kv_block_size=vllm_config.cache_config.block_size, ).replace("*", "127.0.0.1")
zmq_endpoint=zmq_endpoint,
) zmq_config = ZmqKvEventPublisherConfig(
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
logger.info(f"Worker reading KV events from {zmq_endpoint}") logger.info(
f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
)
return kv_publisher return kv_publishers if kv_publishers else None
def setup_vllm_engine(config, stat_logger=None): def setup_vllm_engine(config, stat_logger=None):
...@@ -200,12 +207,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -200,12 +207,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params runtime, component, engine_client, default_sampling_params
) )
# Set up KV event publisher for prefix caching if enabled # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publisher = setup_kv_event_publisher( kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config config, component, generate_endpoint, vllm_config
) )
if kv_publisher: if kv_publishers:
handler.kv_publisher = kv_publisher handler.kv_publishers = kv_publishers
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
...@@ -285,12 +292,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -285,12 +292,12 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client, prefill_router_client,
) )
# Set up KV event publisher for prefix caching if enabled # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publisher = setup_kv_event_publisher( kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config config, component, generate_endpoint, vllm_config
) )
if kv_publisher: if kv_publishers:
handler.kv_publisher = kv_publisher handler.kv_publishers = kv_publishers
if config.engine_args.disable_log_stats is False: if config.engine_args.disable_log_stats is False:
from prometheus_client import REGISTRY from prometheus_client import REGISTRY
...@@ -311,6 +318,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -311,6 +318,12 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.tool_call_parser = config.tool_call_parser runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(
vllm_config.parallel_config, "data_parallel_size", 1
)
runtime_config.data_parallel_size = data_parallel_size
await register_llm( await register_llm(
ModelInput.Tokens, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions, ModelType.Chat | ModelType.Completions,
......
...@@ -322,7 +322,11 @@ The `KvPushRouter` provides the following methods: ...@@ -322,7 +322,11 @@ The `KvPushRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management. - **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`. - **`best_worker(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: **[DEPRECATED - use `best_worker()` instead]** Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state - Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking - With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
......
...@@ -207,6 +207,7 @@ fn kv_event_create_stored_from_parts( ...@@ -207,6 +207,7 @@ fn kv_event_create_stored_from_parts(
parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash), parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
}), }),
event_id: kv_params.event_id, event_id: kv_params.event_id,
dp_rank: 0,
} }
} }
...@@ -224,6 +225,7 @@ fn kv_event_create_removed_from_parts( ...@@ -224,6 +225,7 @@ fn kv_event_create_removed_from_parts(
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: 0,
} }
} }
......
...@@ -101,7 +101,7 @@ impl WorkerMetricsPublisher { ...@@ -101,7 +101,7 @@ impl WorkerMetricsPublisher {
#[derive(Clone)] #[derive(Clone)]
pub struct ZmqKvEventPublisherConfig { pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)] #[pyo3(get, set)]
pub worker_id: i64, pub worker_id: WorkerId,
#[pyo3(get, set)] #[pyo3(get, set)]
pub kv_block_size: usize, pub kv_block_size: usize,
#[pyo3(get, set)] #[pyo3(get, set)]
...@@ -120,7 +120,7 @@ impl ZmqKvEventPublisherConfig { ...@@ -120,7 +120,7 @@ impl ZmqKvEventPublisherConfig {
zmq_topic = "".to_string() zmq_topic = "".to_string()
))] ))]
pub fn new( pub fn new(
worker_id: i64, worker_id: WorkerId,
kv_block_size: usize, kv_block_size: usize,
zmq_endpoint: String, zmq_endpoint: String,
zmq_topic: String, zmq_topic: String,
...@@ -234,13 +234,20 @@ impl Drop for ZmqKvEventListener { ...@@ -234,13 +234,20 @@ impl Drop for ZmqKvEventListener {
pub(crate) struct KvEventPublisher { pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>, inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
kv_block_size: usize, kv_block_size: usize,
dp_rank: DpRank,
warning_count: Arc<AtomicU32>, warning_count: Arc<AtomicU32>,
} }
#[pymethods] #[pymethods]
impl KvEventPublisher { impl KvEventPublisher {
#[new] #[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> { #[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
} }
...@@ -256,6 +263,7 @@ impl KvEventPublisher { ...@@ -256,6 +263,7 @@ impl KvEventPublisher {
Ok(Self { Ok(Self {
inner: inner.into(), inner: inner.into(),
kv_block_size, kv_block_size,
dp_rank,
warning_count: Arc::new(AtomicU32::new(0)), warning_count: Arc::new(AtomicU32::new(0)),
}) })
} }
...@@ -286,6 +294,7 @@ impl KvEventPublisher { ...@@ -286,6 +294,7 @@ impl KvEventPublisher {
&self.warning_count, &self.warning_count,
), ),
}), }),
dp_rank: self.dp_rank,
}; };
self.inner.publish(event).map_err(to_pyerr) self.inner.publish(event).map_err(to_pyerr)
...@@ -299,6 +308,7 @@ impl KvEventPublisher { ...@@ -299,6 +308,7 @@ impl KvEventPublisher {
let event = KvCacheEvent { let event = KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: self.dp_rank,
}; };
self.inner.publish(event).map_err(to_pyerr) self.inner.publish(event).map_err(to_pyerr)
...@@ -314,8 +324,13 @@ pub(crate) struct OverlapScores { ...@@ -314,8 +324,13 @@ pub(crate) struct OverlapScores {
#[pymethods] #[pymethods]
impl OverlapScores { impl OverlapScores {
#[getter] #[getter]
fn scores(&self) -> HashMap<llm_rs::kv_router::indexer::WorkerId, u32> { fn scores(&self) -> HashMap<(i64, u32), u32> {
self.inner.scores.clone() // Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
self.inner
.scores
.iter()
.map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
.collect()
} }
#[getter] #[getter]
...@@ -361,7 +376,7 @@ impl RadixTree { ...@@ -361,7 +376,7 @@ impl RadixTree {
fn apply_event( fn apply_event(
&mut self, &mut self,
_py: Python, _py: Python,
worker_id: i64, worker_id: WorkerId,
kv_cache_event_bytes: &[u8], kv_cache_event_bytes: &[u8],
) -> PyResult<()> { ) -> PyResult<()> {
let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent = let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent =
...@@ -377,12 +392,12 @@ impl RadixTree { ...@@ -377,12 +392,12 @@ impl RadixTree {
Ok(()) Ok(())
} }
fn remove_worker(&mut self, _py: Python, worker_id: i64) -> PyResult<()> { fn remove_worker(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.remove_worker(worker_id); self.inner.remove_worker(worker_id);
Ok(()) Ok(())
} }
fn clear_all_blocks(&mut self, _py: Python, worker_id: i64) -> PyResult<()> { fn clear_all_blocks(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.clear_all_blocks(worker_id); self.inner.clear_all_blocks(worker_id);
Ok(()) Ok(())
} }
...@@ -517,16 +532,19 @@ impl ApproxKvIndexer { ...@@ -517,16 +532,19 @@ impl ApproxKvIndexer {
}) })
} }
#[pyo3(signature = (tokens, worker_id, dp_rank=0))]
fn process_routing_decision_for_request<'p>( fn process_routing_decision_for_request<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
tokens: Vec<u32>, tokens: Vec<u32>,
worker_id: i64, worker_id: WorkerId,
dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone(); let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
indexer indexer
.process_routing_decision_for_request(tokens.as_slice(), worker_id) .process_routing_decision_for_request(tokens.as_slice(), worker)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
...@@ -538,7 +556,7 @@ impl ApproxKvIndexer { ...@@ -538,7 +556,7 @@ impl ApproxKvIndexer {
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct EndpointKvMetrics { pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)] #[pyo3(get, set)]
pub worker_id: i64, pub worker_id: WorkerId,
#[pyo3(get, set)] #[pyo3(get, set)]
pub request_active_slots: u64, pub request_active_slots: u64,
#[pyo3(get, set)] #[pyo3(get, set)]
...@@ -784,7 +802,7 @@ impl WorkerStats { ...@@ -784,7 +802,7 @@ impl WorkerStats {
request_active_slots: u64, request_active_slots: u64,
request_total_slots: u64, request_total_slots: u64,
num_requests_waiting: u64, num_requests_waiting: u64,
data_parallel_rank: Option<u32>, data_parallel_rank: Option<DpRank>,
) -> Self { ) -> Self {
Self(RsWorkerStats { Self(RsWorkerStats {
data_parallel_rank, data_parallel_rank,
...@@ -961,7 +979,7 @@ impl KvPushRouter { ...@@ -961,7 +979,7 @@ impl KvPushRouter {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, extra_args=None))] #[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None))]
fn generate<'p>( fn generate<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
...@@ -971,7 +989,8 @@ impl KvPushRouter { ...@@ -971,7 +989,8 @@ impl KvPushRouter {
sampling_options: Option<PyObject>, sampling_options: Option<PyObject>,
output_options: Option<PyObject>, output_options: Option<PyObject>,
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
worker_id: Option<i64>, worker_id: Option<WorkerId>,
dp_rank: Option<DpRank>,
extra_args: Option<PyObject>, extra_args: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults // Depythonize the options with defaults
...@@ -1027,6 +1046,7 @@ impl KvPushRouter { ...@@ -1027,6 +1046,7 @@ impl KvPushRouter {
.sampling_options(sampling_options) .sampling_options(sampling_options)
.output_options(output_options) .output_options(output_options)
.router_config_override(router_config_override) .router_config_override(router_config_override)
.dp_rank(dp_rank)
.extra_args(extra_args); .extra_args(extra_args);
// Set backend_instance_id if worker_id is provided // Set backend_instance_id if worker_id is provided
...@@ -1053,6 +1073,43 @@ impl KvPushRouter { ...@@ -1053,6 +1073,43 @@ impl KvPushRouter {
Self::process_request_to_stream(py, self.inner.clone(), request) Self::process_request_to_stream(py, self.inner.clone(), request)
} }
/// Ask the router which worker it would pick for `token_ids`.
///
/// Returns a Python awaitable that resolves to the tuple
/// `(worker_id, dp_rank, overlap_blocks)` taken from the chooser's best match.
///
/// * `token_ids` - token IDs to match.
/// * `router_config_override` - optional Python object that is depythonized into a
///   `RouterConfigOverride`; a conversion failure is raised as a Python error
///   before any async work is scheduled.
/// * `request_id` - when present, `update_states` is passed as `true` to
///   `find_best_match`; when absent it is `false`.
///
/// Errors from `find_best_match` are converted with `to_pyerr` and surface when
/// the awaitable is awaited.
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker<'p>(
    &self,
    py: Python<'p>,
    token_ids: Vec<u32>,
    router_config_override: Option<PyObject>,
    request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
    // The `py` token already proves the GIL is held, so bind against it directly
    // instead of re-entering `Python::with_gil`.
    let router_config_override = router_config_override
        .map(|obj| {
            depythonize::<llm_rs::kv_router::RouterConfigOverride>(obj.bind(py))
                .map_err(to_pyerr)
        })
        .transpose()?;

    let chooser = self.inner.chooser.clone();
    // Supplying a request id is what opts the caller into state updates.
    let update_states = request_id.is_some();

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let (best_worker, overlap_blocks) = chooser
            .find_best_match(
                request_id.as_deref(),
                &token_ids,
                router_config_override.as_ref(),
                update_states,
            )
            .await
            .map_err(to_pyerr)?;

        Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
    })
}
/// Deprecated: Use `best_worker()` instead which returns (worker_id, dp_rank, overlap_blocks)
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))] #[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker_id<'p>( fn best_worker_id<'p>(
&self, &self,
...@@ -1061,6 +1118,16 @@ impl KvPushRouter { ...@@ -1061,6 +1118,16 @@ impl KvPushRouter {
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
request_id: Option<String>, request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Issue deprecation warning
let warnings = py.import("warnings")?;
warnings.call_method1(
"warn",
(
"best_worker_id() is deprecated. Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks)",
py.get_type::<pyo3::exceptions::PyDeprecationWarning>(),
),
)?;
let router_config_override = if let Some(obj) = router_config_override { let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| { Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride = let override_config: llm_rs::kv_router::RouterConfigOverride =
...@@ -1075,7 +1142,7 @@ impl KvPushRouter { ...@@ -1075,7 +1142,7 @@ impl KvPushRouter {
let update_states = request_id.is_some(); let update_states = request_id.is_some();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = chooser let (best_worker, overlap_blocks) = chooser
.find_best_match( .find_best_match(
request_id.as_deref(), request_id.as_deref(),
&token_ids, &token_ids,
...@@ -1085,8 +1152,8 @@ impl KvPushRouter { ...@@ -1085,8 +1152,8 @@ impl KvPushRouter {
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
// Return a tuple of (worker_id, overlap_blocks) // Return only worker_id and overlap_blocks for backward compatibility
Ok((worker_id, overlap_blocks)) Ok((best_worker.worker_id, overlap_blocks))
}) })
} }
...@@ -1130,6 +1197,7 @@ impl KvPushRouter { ...@@ -1130,6 +1197,7 @@ impl KvPushRouter {
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
// Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
// Use pythonize to convert Vec<PotentialLoad> to Python list of dicts // Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
Python::with_gil(|py| { Python::with_gil(|py| {
pythonize(py, &loads) pythonize(py, &loads)
......
...@@ -44,6 +44,11 @@ impl ModelRuntimeConfig { ...@@ -44,6 +44,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser = reasoning_parser; self.inner.reasoning_parser = reasoning_parser;
} }
/// Property setter: stores `data_parallel_size` on the wrapped runtime config
/// (`self.inner`). No validation is performed here.
#[setter]
fn set_data_parallel_size(&mut self, data_parallel_size: u32) {
self.inner.data_parallel_size = data_parallel_size;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner self.inner
......
...@@ -778,16 +778,21 @@ class KvEventPublisher: ...@@ -778,16 +778,21 @@ class KvEventPublisher:
... ...
def __init__( def __init__(
self, component: Component, worker_id: int, kv_block_size: int self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0
) -> None: ) -> None:
""" """
Create a `KvEventPublisher` object Create a `KvEventPublisher` object
Args:
component: The component to publish events for
worker_id: The worker ID
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
""" """
def publish_stored( def publish_stored(
self, self,
event_id, event_id: int,
int,
token_ids: List[int], token_ids: List[int],
num_block_tokens: List[int], num_block_tokens: List[int],
block_hashes: List[int], block_hashes: List[int],
...@@ -796,12 +801,24 @@ class KvEventPublisher: ...@@ -796,12 +801,24 @@ class KvEventPublisher:
) -> None: ) -> None:
""" """
Publish a KV stored event. Publish a KV stored event.
Args:
event_id: The event ID
token_ids: List of token IDs
num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
""" """
... ...
def publish_removed(self, event_id, int, block_hashes: List[int]) -> None: def publish_removed(self, event_id: int, block_hashes: List[int]) -> None:
""" """
Publish a KV removed event. Publish a KV removed event.
Args:
event_id: The event ID
block_hashes: List of block hashes to remove (signed 64-bit integers)
""" """
... ...
...@@ -1199,6 +1216,7 @@ class KvPushRouter: ...@@ -1199,6 +1216,7 @@ class KvPushRouter:
output_options: Optional[JsonLike] = None, output_options: Optional[JsonLike] = None,
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None, worker_id: Optional[int] = None,
dp_rank: Optional[int] = None,
) -> AsyncIterator[JsonLike]: ) -> AsyncIterator[JsonLike]:
""" """
Generate text using the KV-aware router. Generate text using the KV-aware router.
...@@ -1213,6 +1231,10 @@ class KvPushRouter: ...@@ -1213,6 +1231,10 @@ class KvPushRouter:
worker_id: Optional worker ID to route to directly. If set, the request worker_id: Optional worker ID to route to directly. If set, the request
will be sent to this specific worker and router states will be will be sent to this specific worker and router states will be
updated accordingly. updated accordingly.
dp_rank: Optional data parallel rank to route to. If set along with worker_id,
the request will be routed to the specific (worker_id, dp_rank) pair.
If only dp_rank is set, the router will select the best worker but
force routing to the specified dp_rank.
Returns: Returns:
An async iterator yielding generation responses An async iterator yielding generation responses
...@@ -1220,10 +1242,36 @@ class KvPushRouter: ...@@ -1220,10 +1242,36 @@ class KvPushRouter:
Note: Note:
- If worker_id is set, the request bypasses KV matching and routes directly - If worker_id is set, the request bypasses KV matching and routes directly
to the specified worker while still updating router states. to the specified worker while still updating router states.
- dp_rank allows targeting a specific data parallel replica when workers have
multiple replicas (data_parallel_size > 1).
- This is different from query_instance_id which doesn't route the request. - This is different from query_instance_id which doesn't route the request.
""" """
... ...
async def best_worker(
self,
token_ids: List[int],
router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None,
) -> Tuple[int, int, int]:
"""
Find the best matching worker for the given tokens.
Args:
token_ids: List of token IDs to find matches for
router_config_override: Optional router configuration override
request_id: Optional request ID. If provided, router states will be updated
to track this request (active blocks, lifecycle events). If not
provided, this is a query-only operation that doesn't affect state.
Returns:
A tuple of (worker_id, dp_rank, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- dp_rank: The data parallel rank of the selected worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def best_worker_id( async def best_worker_id(
self, self,
token_ids: List[int], token_ids: List[int],
...@@ -1231,6 +1279,8 @@ class KvPushRouter: ...@@ -1231,6 +1279,8 @@ class KvPushRouter:
request_id: Optional[str] = None, request_id: Optional[str] = None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
""" """
[DEPRECATED] Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks).
Find the best matching worker for the given tokens. Find the best matching worker for the given tokens.
Args: Args:
...@@ -1244,6 +1294,9 @@ class KvPushRouter: ...@@ -1244,6 +1294,9 @@ class KvPushRouter:
A tuple of (worker_id, overlap_blocks) where: A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker - worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found - overlap_blocks: The number of overlapping blocks found
.. deprecated::
Use :meth:`best_worker` instead which also returns dp_rank.
""" """
... ...
...@@ -1260,8 +1313,13 @@ class KvPushRouter: ...@@ -1260,8 +1313,13 @@ class KvPushRouter:
Returns: Returns:
A list of dictionaries, each containing: A list of dictionaries, each containing:
- worker_id: The worker ID - worker_id: The worker ID
- dp_rank: The data parallel rank
- potential_prefill_tokens: Number of tokens that would need prefill - potential_prefill_tokens: Number of tokens that would need prefill
- potential_decode_blocks: Number of blocks currently in decode phase - potential_decode_blocks: Number of blocks currently in decode phase
Note:
Each (worker_id, dp_rank) pair is returned as a separate entry.
If you need aggregated loads per worker_id, sum the values manually.
""" """
... ...
...@@ -1287,7 +1345,7 @@ class KvPushRouter: ...@@ -1287,7 +1345,7 @@ class KvPushRouter:
Note: Note:
This is typically called automatically by the router when using the This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using `generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing. `best_worker()` with `request_id` for custom routing.
""" """
... ...
...@@ -1304,7 +1362,7 @@ class KvPushRouter: ...@@ -1304,7 +1362,7 @@ class KvPushRouter:
Note: Note:
This is typically called automatically by the router when using the This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using `generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing. `best_worker()` with `request_id` for custom routing.
""" """
... ...
......
...@@ -85,17 +85,21 @@ async def test_radix_tree_binding(distributed_runtime): ...@@ -85,17 +85,21 @@ async def test_radix_tree_binding(distributed_runtime):
overlap_scores = radix_tree.find_matches([0]) overlap_scores = radix_tree.find_matches([0])
# Verify the results # Verify the results
# Note: scores is now Dict[(worker_id, dp_rank), score]
assert overlap_scores.scores is not None assert overlap_scores.scores is not None
assert ( assert (
len(overlap_scores.scores) == 1 len(overlap_scores.scores) == 1
), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}" ), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}"
assert worker_id in overlap_scores.scores, f"Worker {worker_id} not found in scores" worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert ( assert (
overlap_scores.scores[worker_id] == 1 worker_key in overlap_scores.scores
), f"Expected score 1 for worker {worker_id}, got {overlap_scores.scores[worker_id]}" ), f"Worker {worker_key} not found in scores"
assert (
overlap_scores.scores[worker_key] == 1
), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}"
print( print(
f"✓ RadixTree test passed: worker {worker_id} has score {overlap_scores.scores[worker_id]}" f"✓ RadixTree test passed: worker {worker_key} has score {overlap_scores.scores[worker_key]}"
) )
...@@ -130,24 +134,25 @@ async def test_event_handler(distributed_runtime): ...@@ -130,24 +134,25 @@ async def test_event_handler(distributed_runtime):
event_publisher.store_event(test_token, lora_id) event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously # wait for the event to be processed as it is sent asynchronously
# Retry loop for CI environments where processing may take longer # Retry loop for CI environments where processing may take longer
worker_key = (worker_id, 0) # (worker_id, dp_rank)
for retry in range(10): # Try up to 10 times for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
if ( if (
scores.scores scores.scores
and worker_id in scores.scores and worker_key in scores.scores
and scores.scores[worker_id] == 1 and scores.scores[worker_key] == 1
): ):
break break
if retry == 9: # Last iteration if retry == 9: # Last iteration
# Provide detailed error message for debugging # Provide detailed error message for debugging
assert scores.scores, f"No scores found after {(retry+1)*0.5}s" assert scores.scores, f"No scores found after {(retry+1)*0.5}s"
assert ( assert (
worker_id in scores.scores worker_key in scores.scores
), f"Worker {worker_id} not in scores after {(retry+1)*0.5}s" ), f"Worker {worker_key} not in scores after {(retry+1)*0.5}s"
assert ( assert (
scores.scores[worker_id] == 1 scores.scores[worker_key] == 1
), f"Expected score 1, got {scores.scores.get(worker_id)} after {(retry+1)*0.5}s" ), f"Expected score 1, got {scores.scores.get(worker_key)} after {(retry+1)*0.5}s"
# remove event # remove event
event_publisher.remove_event() event_publisher.remove_event()
...@@ -185,8 +190,9 @@ async def test_approx_kv_indexer(distributed_runtime): ...@@ -185,8 +190,9 @@ async def test_approx_kv_indexer(distributed_runtime):
scores = await indexer.find_matches_for_request(tokens) scores = await indexer.find_matches_for_request(tokens)
assert scores.scores assert scores.scores
assert worker_id in scores.scores worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert scores.scores[worker_id] == 2 assert worker_key in scores.scores
assert scores.scores[worker_key] == 2
class EventPublisher: class EventPublisher:
...@@ -281,7 +287,7 @@ async def metrics_publisher_task(kv_listener, expected_metrics): ...@@ -281,7 +287,7 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics["request_active_slots"], expected_metrics["request_active_slots"],
expected_metrics["request_total_slots"], expected_metrics["request_total_slots"],
expected_metrics["num_requests_waiting"], expected_metrics["num_requests_waiting"],
None, 0, # data_parallel_rank (0 = DP not enabled)
) )
kv_stats = KvStats( kv_stats = KvStats(
......
...@@ -38,7 +38,9 @@ use crate::{ ...@@ -38,7 +38,9 @@ use crate::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent, KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block, compute_block_hash_for_seq, compute_seq_hash_for_block,
}, },
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
subscriber::start_kv_router_background, subscriber::start_kv_router_background,
...@@ -74,7 +76,7 @@ pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock"; ...@@ -74,7 +76,7 @@ pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
pub trait WorkerSelector { pub trait WorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>, workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>; ) -> Result<WorkerSelectionResult, KvSchedulerError>;
...@@ -316,7 +318,7 @@ impl KvRouter { ...@@ -316,7 +318,7 @@ impl KvRouter {
} }
/// Given these tokens, find the worker with the best match in its KV cache. /// Given these tokens, find the worker with the best match in its KV cache.
/// Returned overlap amount is in number of blocks. /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking /// Now also takes optional context_id for request tracking
pub async fn find_best_match( pub async fn find_best_match(
&self, &self,
...@@ -324,7 +326,7 @@ impl KvRouter { ...@@ -324,7 +326,7 @@ impl KvRouter {
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
) -> anyhow::Result<(i64, u32)> { ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
// Validate that context_id is provided when update_states is true // Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() { if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true"); panic!("context_id must be provided if update_states is true");
...@@ -350,7 +352,7 @@ impl KvRouter { ...@@ -350,7 +352,7 @@ impl KvRouter {
(false, false) => (None, None), (false, false) => (None, None),
}; };
let best_worker_id = self let best_worker = self
.scheduler .scheduler
.schedule( .schedule(
context_id.map(|s| s.to_string()), context_id.map(|s| s.to_string()),
...@@ -364,17 +366,17 @@ impl KvRouter { ...@@ -364,17 +366,17 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer indexer
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap()) .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await .await
.unwrap(); .unwrap();
}; };
let overlap_amount = overlap_scores let overlap_amount = overlap_scores
.scores .scores
.get(&best_worker_id) .get(&best_worker)
.copied() .copied()
.unwrap_or(0); .unwrap_or(0);
Ok((best_worker_id, overlap_amount)) Ok((best_worker, overlap_amount))
} }
pub async fn add_request( pub async fn add_request(
...@@ -382,7 +384,7 @@ impl KvRouter { ...@@ -382,7 +384,7 @@ impl KvRouter {
request_id: String, request_id: String,
tokens: &[u32], tokens: &[u32],
overlap_blocks: u32, overlap_blocks: u32,
worker_id: i64, worker: WorkerWithDpRank,
) { ) {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
...@@ -397,7 +399,7 @@ impl KvRouter { ...@@ -397,7 +399,7 @@ impl KvRouter {
maybe_seq_hashes, maybe_seq_hashes,
isl_tokens, isl_tokens,
overlap_blocks, overlap_blocks,
worker_id, worker,
) )
.await; .await;
} }
...@@ -450,12 +452,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -450,12 +452,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
// Handle different request types // Handle different request types
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New { tokens } => {
let (worker_id, overlap_blocks) = self let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true) .find_best_match(Some(&context_id), &tokens, None, true)
.await?; .await?;
RouterResponse::New { RouterResponse::New {
worker_id, worker_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_blocks, overlap_blocks,
} }
} }
...@@ -523,24 +526,45 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -523,24 +526,45 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Check if this is a query_instance_id request first // Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id"); let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id { let (instance_id, dp_rank, overlap_amount) = if let Some(id) =
// If instance_id is set, use it and manually add the request to track it request.backend_instance_id
if !query_instance_id { {
self.chooser // If instance_id is set, use it and compute actual overlap
.add_request(context_id.clone(), &request.token_ids, 0, id) let dp_rank = request.dp_rank.unwrap_or(0);
.await; if query_instance_id {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
} }
(id, 0)
// Compute actual overlap blocks by querying the indexer
let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
self.chooser
.add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
(id, dp_rank, overlap_blocks)
} else { } else {
// Otherwise, find the best match // Otherwise, find the best match
self.chooser let (best_worker, overlap_amount) = self
.chooser
.find_best_match( .find_best_match(
Some(&context_id), Some(&context_id),
&request.token_ids, &request.token_ids,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id !query_instance_id, // Don't update states if query_instance_id
) )
.await? .await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
}; };
// if request has the annotation "query_instance_id", // if request has the annotation "query_instance_id",
...@@ -564,6 +588,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -564,6 +588,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
......
...@@ -27,11 +27,11 @@ use crate::tokens::{SequenceHash, TokenBlockSequence}; ...@@ -27,11 +27,11 @@ use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
WorkerId, compute_block_hash_for_seq, compute_block_hash_for_seq,
}; };
use crate::kv_router::protocols::{ use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, KvCacheStoredBlockData, LocalBlockHash, WorkerId, WorkerWithDpRank,
}; };
#[derive(Debug)] #[derive(Debug)]
...@@ -44,8 +44,8 @@ struct MatchRequest { ...@@ -44,8 +44,8 @@ struct MatchRequest {
#[derive(Debug)] #[derive(Debug)]
struct RouterResult { struct RouterResult {
/// The id of the selected worker. /// The worker (with dp_rank) that was selected.
worker_id: WorkerId, worker: WorkerWithDpRank,
/// The local hashes of the tokens sent to the worker. /// The local hashes of the tokens sent to the worker.
local_hashes: Vec<LocalBlockHash>, local_hashes: Vec<LocalBlockHash>,
...@@ -58,8 +58,8 @@ struct RouterResult { ...@@ -58,8 +58,8 @@ struct RouterResult {
struct TimerEntry { struct TimerEntry {
/// The key of the timer. /// The key of the timer.
key: ExternalSequenceBlockHash, key: ExternalSequenceBlockHash,
/// The worker id that stored this block. /// The worker (with dp_rank) that stored this block.
worker: WorkerId, worker: WorkerWithDpRank,
} }
/// A data structure to manage a collection of timers, addressable by a key. /// A data structure to manage a collection of timers, addressable by a key.
...@@ -237,10 +237,11 @@ impl ApproxKvIndexer { ...@@ -237,10 +237,11 @@ impl ApproxKvIndexer {
event_id += 1; event_id += 1;
let event = RouterEvent::new( let event = RouterEvent::new(
result.worker_id, result.worker.worker_id,
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: stored_event, data: stored_event,
dp_rank: result.worker.dp_rank,
} }
); );
...@@ -248,7 +249,7 @@ impl ApproxKvIndexer { ...@@ -248,7 +249,7 @@ impl ApproxKvIndexer {
timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry { timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry {
key: ExternalSequenceBlockHash(*h), key: ExternalSequenceBlockHash(*h),
worker: result.worker_id, worker: result.worker,
}).collect()); }).collect());
} }
...@@ -269,12 +270,13 @@ impl ApproxKvIndexer { ...@@ -269,12 +270,13 @@ impl ApproxKvIndexer {
event_id += 1; event_id += 1;
let event = RouterEvent::new( let event = RouterEvent::new(
e.worker, e.worker.worker_id,
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key], block_hashes: vec![e.key],
}), }),
dp_rank: e.worker.dp_rank,
} }
); );
...@@ -307,13 +309,13 @@ impl ApproxKvIndexer { ...@@ -307,13 +309,13 @@ impl ApproxKvIndexer {
/// Core function to process a routing decision with pre-computed hashes /// Core function to process a routing decision with pre-computed hashes
pub async fn process_routing_decision( pub async fn process_routing_decision(
&self, &self,
worker_id: WorkerId, worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>, local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>, sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
self.route_tx self.route_tx
.send(RouterResult { .send(RouterResult {
worker_id, worker,
local_hashes, local_hashes,
sequence_hashes, sequence_hashes,
}) })
...@@ -327,7 +329,7 @@ impl ApproxKvIndexer { ...@@ -327,7 +329,7 @@ impl ApproxKvIndexer {
pub async fn process_routing_decision_for_request( pub async fn process_routing_decision_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
worker_id: WorkerId, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size); let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
...@@ -338,7 +340,7 @@ impl ApproxKvIndexer { ...@@ -338,7 +340,7 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash()) .map(|b| b.sequence_hash())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.process_routing_decision(worker_id, local_hashes, sequence_hashes) self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await .await
} }
} }
...@@ -526,14 +528,20 @@ mod tests { ...@@ -526,14 +528,20 @@ mod tests {
// 2. Inform indexer about routing decision // 2. Inform indexer about routing decision
indexer indexer
.process_routing_decision_for_request(&tokens, worker_id) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await .await
.unwrap(); .unwrap();
// Poll until we observe the match being registered // Poll until we observe the match being registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_id).copied() == Some(1) s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_id))
.copied()
== Some(1)
}) })
.await; .await;
...@@ -554,14 +562,18 @@ mod tests { ...@@ -554,14 +562,18 @@ mod tests {
let worker_id: WorkerId = 7; let worker_id: WorkerId = 7;
indexer indexer
.process_routing_decision_for_request(&tokens, worker_id) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await .await
.unwrap(); .unwrap();
// Wait until the worker is registered // Wait until the worker is registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.contains_key(&worker_id) s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
}) })
.await; .await;
...@@ -571,7 +583,8 @@ mod tests { ...@@ -571,7 +583,8 @@ mod tests {
// Ensure the worker's entries are gone // Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_id) !s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
}) })
.await; .await;
} }
...@@ -590,19 +603,31 @@ mod tests { ...@@ -590,19 +603,31 @@ mod tests {
// Register on both workers // Register on both workers
indexer indexer
.process_routing_decision_for_request(&tokens, worker_0) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await .await
.unwrap(); .unwrap();
indexer indexer
.process_routing_decision_for_request(&tokens, worker_1) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await .await
.unwrap(); .unwrap();
// Ensure both workers are registered // Ensure both workers are registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1) s.scores
&& s.scores.get(&worker_1).copied() == Some(1) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
}) })
.await; .await;
...@@ -612,7 +637,12 @@ mod tests { ...@@ -612,7 +637,12 @@ mod tests {
// Confirm the removed worker is gone, and the other remains. // Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_0) && s.scores.get(&worker_1).copied() == Some(1) !s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
}) })
.await; .await;
} }
...@@ -631,14 +661,20 @@ mod tests { ...@@ -631,14 +661,20 @@ mod tests {
// Register Sequence A on worker A // Register Sequence A on worker A
indexer indexer
.process_routing_decision_for_request(&seq_a, worker_a) .process_routing_decision_for_request(
&seq_a,
WorkerWithDpRank::from_worker_id(worker_a),
)
.await .await
.unwrap(); .unwrap();
// Ensure the indexer has registered the block // Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&seq_a).await.unwrap(); let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
s.scores.get(&worker_a).copied() == Some(1) s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a))
.copied()
== Some(1)
}) })
.await; .await;
...@@ -649,7 +685,12 @@ mod tests { ...@@ -649,7 +685,12 @@ mod tests {
let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap(); let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();
// Expect worker A to have an overlap score of 1 (shared first block) // Expect worker A to have an overlap score of 1 (shared first block)
assert_eq!(overlap.scores.get(&worker_a), Some(&1)); assert_eq!(
overlap
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a)),
Some(&1)
);
} }
/// When the same block resides on multiple workers, all should appear in the overlap scores. /// When the same block resides on multiple workers, all should appear in the overlap scores.
...@@ -666,25 +707,47 @@ mod tests { ...@@ -666,25 +707,47 @@ mod tests {
// Register the same sequence on two different workers // Register the same sequence on two different workers
indexer indexer
.process_routing_decision_for_request(&tokens, worker_0) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await .await
.unwrap(); .unwrap();
indexer indexer
.process_routing_decision_for_request(&tokens, worker_1) .process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await .await
.unwrap(); .unwrap();
// Wait until both workers are reflected in overlap scores // Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1) s.scores
&& s.scores.get(&worker_1).copied() == Some(1) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
}) })
.await; .await;
let scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert_eq!(scores.scores.get(&worker_0), Some(&1)); assert_eq!(
assert_eq!(scores.scores.get(&worker_1), Some(&1)); scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0)),
Some(&1)
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1)),
Some(&1)
);
} }
} }
...@@ -80,9 +80,6 @@ pub enum KvCacheEventError { ...@@ -80,9 +80,6 @@ pub enum KvCacheEventError {
BlockNotFound, BlockNotFound,
} }
/// Identifier of a LLM worker which emits events to the router.
pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`]. /// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>; type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
...@@ -200,9 +197,9 @@ impl RouterEvent { ...@@ -200,9 +197,9 @@ impl RouterEvent {
struct RadixBlock { struct RadixBlock {
/// A map of child blocks, keyed by their local block hash. /// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>, children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// A map of worker IDs to their external sequence block hash for this block. /// A map of workers (with dp_rank) to their external sequence block hash for this block.
/// The external hash is preserved to speed up snapshotting. /// The external hash is preserved to speed up snapshotting.
workers: HashMap<WorkerId, ExternalSequenceBlockHash>, workers: HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed /// A buffer of times that this block was last traversed
recent_uses: VecDeque<Instant>, recent_uses: VecDeque<Instant>,
} }
...@@ -235,7 +232,7 @@ pub struct RadixTree { ...@@ -235,7 +232,7 @@ pub struct RadixTree {
/// Transitioning to a radix tree only would require a change in the messaging structure /// Transitioning to a radix tree only would require a change in the messaging structure
/// as the entire prefix would need to be sent. Alternatively, we could use block_depth /// as the entire prefix would need to be sent. Alternatively, we could use block_depth
/// integers to indicate how many blocks to skip and use a radix/prefix tree at each level. /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>, lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering frequency of block accesses /// The time buffer the radix tree should check when considering frequency of block accesses
expiration_duration: Option<Duration>, expiration_duration: Option<Duration>,
} }
...@@ -332,11 +329,15 @@ impl RadixTree { ...@@ -332,11 +329,15 @@ impl RadixTree {
/// ///
/// * `event` - The `RouterEvent` to apply. /// * `event` - The `RouterEvent` to apply.
pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> { pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
let (worker_id, event) = (event.worker_id, event.event); let (worker_id, kv_event) = (event.worker_id, event.event);
let (id, op) = (event.event_id, event.data); let (id, op) = (kv_event.event_id, kv_event.data);
// Construct WorkerWithDpRank from worker_id and dp_rank from the event
let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
tracing::trace!(id, "RadixTree::apply_event: Store operation: {:?}", op); tracing::trace!(id, "RadixTree::apply_event: Store operation: {:?}", op);
let worker_lookup = self.lookup.entry(worker_id).or_default(); let worker_lookup = self.lookup.entry(worker).or_default();
match op { match op {
KvCacheEventData::Stored(op) => { KvCacheEventData::Stored(op) => {
...@@ -352,7 +353,8 @@ impl RadixTree { ...@@ -352,7 +353,8 @@ impl RadixTree {
Some(current) => current.clone(), Some(current) => current.clone(),
None => { None => {
tracing::warn!( tracing::warn!(
worker_id = worker_id.to_string(), worker_id = worker.worker_id.to_string(),
dp_rank = ?worker.dp_rank,
id, id,
parent_hash = ?op.parent_hash, parent_hash = ?op.parent_hash,
"Failed to find parent block; skipping store operation" "Failed to find parent block; skipping store operation"
...@@ -381,11 +383,11 @@ impl RadixTree { ...@@ -381,11 +383,11 @@ impl RadixTree {
} }
}; };
// add our worker_id to the block with its external hash // add our worker to the block with its external hash
block block
.borrow_mut() .borrow_mut()
.workers .workers
.insert(worker_id, block_id.block_hash); .insert(worker, block_id.block_hash);
// add the block to the worker lookup table // add the block to the worker lookup table
worker_lookup.insert(block_id.block_hash, block.clone()); worker_lookup.insert(block_id.block_hash, block.clone());
...@@ -419,7 +421,7 @@ impl RadixTree { ...@@ -419,7 +421,7 @@ impl RadixTree {
}; };
let mut guard = entry.borrow_mut(); let mut guard = entry.borrow_mut();
guard.workers.remove(&worker_id); guard.workers.remove(&worker);
if guard.workers.is_empty() { if guard.workers.is_empty() {
// if no workers are using this block, that is true for all children // if no workers are using this block, that is true for all children
guard.children.clear(); guard.children.clear();
...@@ -430,48 +432,57 @@ impl RadixTree { ...@@ -430,48 +432,57 @@ impl RadixTree {
Ok(()) Ok(())
} }
KvCacheEventData::Cleared => { KvCacheEventData::Cleared => {
self.clear_all_blocks(worker_id); self.clear_all_blocks(worker.worker_id);
Ok(()) Ok(())
} }
} }
} }
pub fn remove_worker(&mut self, worker: WorkerId) { /// Helper function to remove or clear blocks for a worker.
if let Some((_, blocks)) = self.lookup.remove_entry(&worker) { /// If `keep_worker` is true, the worker remains in lookup with empty blocks.
blocks.iter().for_each(|(_, block)| { /// If `keep_worker` is false, the worker is completely removed from lookup.
block.borrow_mut().workers.remove(&worker); fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
// If no workers are using this block, that is true for all children // Collect all WorkerWithDpRank keys that match this worker_id
if block.borrow().workers.is_empty() { let workers: Vec<WorkerWithDpRank> = self
block.borrow_mut().children.clear(); .lookup
} .keys()
}); .filter(|w| w.worker_id == worker_id)
} .copied()
} .collect();
pub fn clear_all_blocks(&mut self, worker: WorkerId) { for worker in workers {
// Check if the worker has any blocks to clear if let Some((worker_key, blocks)) = self.lookup.remove_entry(&worker) {
if let Some(blocks) = self.lookup.get(&worker) { blocks.iter().for_each(|(_, block)| {
let blocks_to_clear: Vec<_> = blocks.values().collect(); block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
// Remove the worker from each block's workers map if keep_worker {
blocks_to_clear.iter().for_each(|block| { // Re-insert worker with empty blocks map to keep it tracked
block.borrow_mut().workers.remove(&worker); self.lookup.insert(worker_key, HashMap::new());
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
} }
});
// Clear the worker's blocks
if let Some(worker_lookup) = self.lookup.get_mut(&worker) {
worker_lookup.clear();
} }
} }
} }
/// Removes `worker_id` from the radix tree entirely.
///
/// Delegates to `remove_or_clear_worker_blocks` with `keep_worker = false`, so the
/// worker's entries are dropped from `lookup` instead of being left behind empty.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
/// Clears every block recorded for `worker_id` while keeping the worker tracked.
///
/// Delegates to `remove_or_clear_worker_blocks` with `keep_worker = true`, so each
/// existing `lookup` entry for this worker is kept, but with an empty block map.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
/// Get all worker IDs currently tracked in the radix tree. /// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub fn get_workers(&self) -> Vec<WorkerId> { pub fn get_workers(&self) -> Vec<WorkerId> {
self.lookup.keys().copied().collect() let mut worker_ids: Vec<WorkerId> = self.lookup.keys().map(|w| w.worker_id).collect();
worker_ids.sort_unstable();
worker_ids.dedup();
worker_ids
} }
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree. /// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
...@@ -487,10 +498,10 @@ impl RadixTree { ...@@ -487,10 +498,10 @@ impl RadixTree {
let mut event_id = 0u64; let mut event_id = 0u64;
// BFS queue: (current_block, parent_hashes_per_worker, tokens_hash) // BFS queue: (current_block, parent_hashes_per_worker, tokens_hash)
// parent_hashes_per_worker maps WorkerId -> ExternalSequenceBlockHash // parent_hashes_per_worker maps WorkerWithDpRank -> ExternalSequenceBlockHash
let mut queue: VecDeque<( let mut queue: VecDeque<(
SharedRadixBlock, SharedRadixBlock,
HashMap<WorkerId, ExternalSequenceBlockHash>, HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
LocalBlockHash, LocalBlockHash,
)> = VecDeque::new(); )> = VecDeque::new();
...@@ -514,7 +525,7 @@ impl RadixTree { ...@@ -514,7 +525,7 @@ impl RadixTree {
// Create a store event for this worker // Create a store event for this worker
let event = RouterEvent { let event = RouterEvent {
worker_id: *worker_id, worker_id: worker_id.worker_id,
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
...@@ -524,6 +535,7 @@ impl RadixTree { ...@@ -524,6 +535,7 @@ impl RadixTree {
tokens_hash, tokens_hash,
}], }],
}), }),
dp_rank: worker_id.dp_rank,
}, },
}; };
events.push(event); events.push(event);
...@@ -639,11 +651,11 @@ impl KvIndexerMetrics { ...@@ -639,11 +651,11 @@ impl KvIndexerMetrics {
} }
} }
/// Scores representing the overlap of workers. /// Scores representing the overlap of workers (with their dp_rank).
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores { pub struct OverlapScores {
// map of worker_id to score // map of worker (with dp_rank) to score
pub scores: HashMap<WorkerId, u32>, pub scores: HashMap<WorkerWithDpRank, u32>,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted. // List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>, pub frequencies: Vec<usize>,
} }
...@@ -671,10 +683,10 @@ impl OverlapScores { ...@@ -671,10 +683,10 @@ impl OverlapScores {
/// ///
/// ### Arguments /// ### Arguments
/// ///
/// * `workers` - An iterator over `WorkerId` references. /// * `workers` - An iterator over `WorkerWithDpRank` references.
pub fn update_scores<'a, I>(&mut self, workers: I) pub fn update_scores<'a, I>(&mut self, workers: I)
where where
I: IntoIterator<Item = &'a WorkerId>, I: IntoIterator<Item = &'a WorkerWithDpRank>,
{ {
for worker in workers { for worker in workers {
let score = self.scores.entry(*worker).or_insert(0); let score = self.scores.entry(*worker).or_insert(0);
...@@ -1344,6 +1356,7 @@ mod tests { ...@@ -1344,6 +1356,7 @@ mod tests {
event: KvCacheEvent { event: KvCacheEvent {
event_id, event_id,
data: add_blocks(hashes, parent), data: add_blocks(hashes, parent),
dp_rank: 0,
}, },
} }
} }
...@@ -1359,6 +1372,7 @@ mod tests { ...@@ -1359,6 +1372,7 @@ mod tests {
.map(|i| ExternalSequenceBlockHash(*i * 100)) .map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(), .collect(),
}), }),
dp_rank: 0,
}, },
} }
} }
...@@ -1379,10 +1393,22 @@ mod tests { ...@@ -1379,10 +1393,22 @@ mod tests {
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)], vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false, false,
); );
assert_eq!(scores.scores.get(&worker_1).unwrap(), &3); assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(trie.lookup.len(), 1); assert_eq!(trie.lookup.len(), 1);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3); assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(trie.root.borrow().workers.len(), 0); assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1); assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!( assert_eq!(
...@@ -1415,12 +1441,36 @@ mod tests { ...@@ -1415,12 +1441,36 @@ mod tests {
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)], vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false, false,
); );
assert_eq!(scores.scores.get(&worker_1).unwrap(), &3); assert_eq!(
assert_eq!(scores.scores.get(&worker_2).unwrap(), &1); scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&1
);
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3); assert_eq!(
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 3); trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
3
);
assert_eq!(trie.root.borrow().workers.len(), 0); assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1); assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!( assert_eq!(
...@@ -1449,8 +1499,20 @@ mod tests { ...@@ -1449,8 +1499,20 @@ mod tests {
trie.apply_event(create_remove_event(worker_2, 2, vec![5])) trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
.unwrap(); .unwrap();
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3); assert_eq!(
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 2); trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
2
);
assert_eq!(trie.root.borrow().workers.len(), 0); assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1); assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!( assert_eq!(
...@@ -1480,8 +1542,20 @@ mod tests { ...@@ -1480,8 +1542,20 @@ mod tests {
.unwrap(); .unwrap();
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3); assert_eq!(
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 1); trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
1
);
assert_eq!(trie.root.borrow().workers.len(), 0); assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1); assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!( assert_eq!(
...@@ -1519,12 +1593,36 @@ mod tests { ...@@ -1519,12 +1593,36 @@ mod tests {
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)], vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false, false,
); );
assert_eq!(scores.scores.get(&worker_1).unwrap(), &3); assert_eq!(
assert_eq!(scores.scores.get(&worker_2).unwrap(), &2); scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&2
);
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3); assert_eq!(
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 4); trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
4
);
assert_eq!(trie.root.borrow().workers.len(), 0); assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1); assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!( assert_eq!(
...@@ -1551,7 +1649,7 @@ mod tests { ...@@ -1551,7 +1649,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
trie.lookup trie.lookup
.get(&worker_1) .get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap() .unwrap()
.get(&ExternalSequenceBlockHash(200)) .get(&ExternalSequenceBlockHash(200))
.unwrap() .unwrap()
...@@ -1562,7 +1660,7 @@ mod tests { ...@@ -1562,7 +1660,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
trie.lookup trie.lookup
.get(&worker_2) .get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap() .unwrap()
.get(&ExternalSequenceBlockHash(200)) .get(&ExternalSequenceBlockHash(200))
.unwrap() .unwrap()
...@@ -1620,12 +1718,16 @@ mod tests { ...@@ -1620,12 +1718,16 @@ mod tests {
.unwrap(); .unwrap();
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores; let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1); assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 1
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
trie.remove_worker(worker_0); trie.remove_worker(worker_0);
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores; let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 1 && result[&worker_1] == 1); assert!(result.len() == 1 && result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1);
} }
#[test] #[test]
...@@ -1643,7 +1745,11 @@ mod tests { ...@@ -1643,7 +1745,11 @@ mod tests {
// Test clearing an empty worker // Test clearing an empty worker
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
assert!(!trie.lookup.contains_key(&worker_0)); assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
// Test clearing a worker with shared blocks // Test clearing a worker with shared blocks
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None)) trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
...@@ -1652,17 +1758,29 @@ mod tests { ...@@ -1652,17 +1758,29 @@ mod tests {
.unwrap(); .unwrap();
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores; let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1); assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 1
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
assert!(trie.lookup.contains_key(&worker_0)); assert!(
assert!(trie.lookup.get(&worker_0).unwrap().is_empty()); trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.is_empty()
);
let result = trie let result = trie
.find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false) .find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false)
.scores; .scores;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 2); assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 2);
let result = trie let result = trie
.find_matches( .find_matches(
vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)], vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)],
...@@ -1670,7 +1788,7 @@ mod tests { ...@@ -1670,7 +1788,7 @@ mod tests {
) )
.scores; .scores;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1); assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 1);
// Test re-adding blocks after clearing worker // Test re-adding blocks after clearing worker
trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None)) trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None))
...@@ -1679,19 +1797,32 @@ mod tests { ...@@ -1679,19 +1797,32 @@ mod tests {
.find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false) .find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false)
.scores; .scores;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[&worker_0], 2); assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_0)], 2);
// Test multiple clears // Test multiple clears
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
assert!(trie.lookup.contains_key(&worker_0)); assert!(
trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
// Test clearing all workers // Test clearing all workers
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
trie.clear_all_blocks(worker_1); trie.clear_all_blocks(worker_1);
assert!(!trie.lookup.is_empty()); assert!(!trie.lookup.is_empty());
assert!(trie.lookup.get(&worker_0).unwrap().is_empty()); assert!(
assert!(trie.lookup.get(&worker_1).unwrap().is_empty()); trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.is_empty()
);
assert!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.is_empty()
);
// Test clearing a worker that has been removed // Test clearing a worker that has been removed
trie.apply_event(create_store_event(worker_0, 0, vec![6], None)) trie.apply_event(create_store_event(worker_0, 0, vec![6], None))
...@@ -1700,20 +1831,35 @@ mod tests { ...@@ -1700,20 +1831,35 @@ mod tests {
.unwrap(); .unwrap();
trie.remove_worker(worker_0); trie.remove_worker(worker_0);
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
assert!(!trie.lookup.contains_key(&worker_0)); assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores; let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1); assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 1);
// Test clearing a worker that doesn't exist // Test clearing a worker that doesn't exist
let worker_fake = 2; let worker_fake = 2;
assert!(!trie.lookup.contains_key(&worker_fake)); assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_fake))
);
trie.clear_all_blocks(worker_fake); trie.clear_all_blocks(worker_fake);
assert!(!trie.lookup.contains_key(&worker_fake)); assert!(
assert!(trie.lookup.contains_key(&worker_1)); !trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_fake))
);
assert!(
trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores; let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1); assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 1);
} }
#[test] #[test]
...@@ -1736,12 +1882,20 @@ mod tests { ...@@ -1736,12 +1882,20 @@ mod tests {
) )
.scores; .scores;
assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1); assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
let result = trie let result = trie
.find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true) .find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true)
.scores; .scores;
assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1); assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
} }
#[rstest] #[rstest]
...@@ -1968,6 +2122,7 @@ mod tests { ...@@ -1968,6 +2122,7 @@ mod tests {
tokens_hash: LocalBlockHash(13226331709069118873), tokens_hash: LocalBlockHash(13226331709069118873),
}], }],
}), }),
dp_rank: 0,
}; };
let router_event = RouterEvent::new(worker_id, kv_cache_event); let router_event = RouterEvent::new(worker_id, kv_cache_event);
...@@ -2239,49 +2394,94 @@ mod tests { ...@@ -2239,49 +2394,94 @@ mod tests {
.unwrap(); .unwrap();
// Verify worker_0 has 3 blocks in lookup // Verify worker_0 has 3 blocks in lookup
assert_eq!(trie.lookup.get(&worker_0).unwrap().len(), 3); assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.len(),
3
);
// Verify that blocks have the correct workers // Verify that blocks have the correct workers
let block_1 = trie let block_1 = trie
.lookup .lookup
.get(&worker_0) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap() .unwrap()
.get(&ExternalSequenceBlockHash(100)) .get(&ExternalSequenceBlockHash(100))
.unwrap(); .unwrap();
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1) assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
assert!(block_1.borrow().workers.contains_key(&worker_0)); assert!(
assert!(block_1.borrow().workers.contains_key(&worker_1)); block_1
assert!(block_1.borrow().workers.contains_key(&worker_2)); .borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
);
// Remove worker_0 // Remove worker_0
trie.remove_worker(worker_0); trie.remove_worker(worker_0);
// Verify worker_0 is completely removed from lookup table // Verify worker_0 is completely removed from lookup table
assert!(!trie.lookup.contains_key(&worker_0)); assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.len(), 2);
// Verify that worker_0's hash is removed from the workers set // Verify that worker_0's hash is removed from the workers set
let block_1 = trie let block_1 = trie
.lookup .lookup
.get(&worker_1) .get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap() .unwrap()
.get(&ExternalSequenceBlockHash(100)) .get(&ExternalSequenceBlockHash(100))
.unwrap(); .unwrap();
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
assert!(!block_1.borrow().workers.contains_key(&worker_0)); assert!(
assert!(block_1.borrow().workers.contains_key(&worker_1)); !block_1
assert!(block_1.borrow().workers.contains_key(&worker_2)); .borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
);
// Verify that blocks with no remaining workers have their children cleared // Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children // This tests the optimization where empty blocks clear their children
let block_2 = trie let block_2 = trie
.lookup .lookup
.get(&worker_1) .get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap() .unwrap()
.get(&ExternalSequenceBlockHash(200)) .get(&ExternalSequenceBlockHash(200))
.unwrap(); .unwrap();
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1 assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
assert!(block_2.borrow().workers.contains_key(&worker_1)); assert!(
block_2
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
// Verify match results no longer include worker_0 // Verify match results no longer include worker_0
let result = trie let result = trie
...@@ -2291,8 +2491,8 @@ mod tests { ...@@ -2291,8 +2491,8 @@ mod tests {
) )
.scores; .scores;
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
assert!(!result.contains_key(&worker_0)); assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
assert!(result.contains_key(&worker_1)); assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
assert!(result.contains_key(&worker_2)); assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)));
} }
} }
...@@ -5,6 +5,34 @@ use crate::tokens::{SequenceHash, Token}; ...@@ -5,6 +5,34 @@ use crate::tokens::{SequenceHash, Token};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
/// A worker identifier.
pub type WorkerId = i64;
/// A data parallel rank identifier.
pub type DpRank = u32;
/// A worker identifier combined with its data parallel rank.
/// Used for routing decisions in data parallel setups.
/// dp_rank = 0 indicates either DP not enabled or the first rank.
///
/// The derived `Eq` and `Hash` allow this type to be used as a `HashMap` key; the
/// derived `Ord` compares `worker_id` first and then `dp_rank` (field declaration order).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerWithDpRank {
/// Identifier of the worker.
pub worker_id: WorkerId,
/// Data parallel rank of the worker; 0 when DP is not enabled or for the first rank.
pub dp_rank: DpRank,
}
impl WorkerWithDpRank {
pub fn new(worker_id: WorkerId, dp_rank: DpRank) -> Self {
Self { worker_id, dp_rank }
}
pub fn from_worker_id(worker_id: WorkerId) -> Self {
Self {
worker_id,
dp_rank: 0,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")] #[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest { pub enum RouterRequest {
...@@ -26,15 +54,24 @@ impl Default for RouterRequest { ...@@ -26,15 +54,24 @@ impl Default for RouterRequest {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")] #[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterResponse { pub enum RouterResponse {
New { worker_id: i64, overlap_blocks: u32 }, New {
PrefillMarked { success: bool }, worker_id: WorkerId,
FreeMarked { success: bool }, #[serde(default)]
dp_rank: DpRank,
overlap_blocks: u32,
},
PrefillMarked {
success: bool,
},
FreeMarked {
success: bool,
},
} }
#[derive(Debug)] #[derive(Debug)]
pub struct WorkerSelectionResult { pub struct WorkerSelectionResult {
/// The worker id of the selected worker /// The full worker information including dp_rank
pub worker_id: i64, pub worker: WorkerWithDpRank,
/// The total number of blocks required to prefill the request /// The total number of blocks required to prefill the request
pub required_blocks: u64, pub required_blocks: u64,
...@@ -54,7 +91,7 @@ pub struct ForwardPassMetrics { ...@@ -54,7 +91,7 @@ pub struct ForwardPassMetrics {
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats { pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models // https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>, pub data_parallel_rank: Option<DpRank>,
pub request_active_slots: u64, pub request_active_slots: u64,
pub request_total_slots: u64, pub request_total_slots: u64,
pub num_requests_waiting: u64, pub num_requests_waiting: u64,
...@@ -136,7 +173,7 @@ impl From<i64> for ExternalSequenceBlockHash { ...@@ -136,7 +173,7 @@ impl From<i64> for ExternalSequenceBlockHash {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillEvent { pub struct PrefillEvent {
pub request_id: String, pub request_id: String,
pub worker_id: i64, pub worker_id: WorkerId,
pub data: PrefillEventData, pub data: PrefillEventData,
pub router_id: Uuid, pub router_id: Uuid,
} }
...@@ -155,7 +192,7 @@ pub enum PrefillEventData { ...@@ -155,7 +192,7 @@ pub enum PrefillEventData {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActiveSequenceEvent { pub struct ActiveSequenceEvent {
pub request_id: String, pub request_id: String,
pub worker_id: i64, pub worker: WorkerWithDpRank,
pub data: ActiveSequenceEventData, pub data: ActiveSequenceEventData,
pub router_id: Uuid, pub router_id: Uuid,
} }
...@@ -199,6 +236,9 @@ pub struct KvCacheEvent { ...@@ -199,6 +236,9 @@ pub struct KvCacheEvent {
pub event_id: u64, pub event_id: u64,
/// The data associated with the event. /// The data associated with the event.
pub data: KvCacheEventData, pub data: KvCacheEventData,
/// The data parallel rank of the worker emitting this event (0 if DP not enabled).
#[serde(default)]
pub dp_rank: DpRank,
} }
/// Represents the data associated with a cache event. /// Represents the data associated with a cache event.
...@@ -313,6 +353,7 @@ mod tests { ...@@ -313,6 +353,7 @@ mod tests {
let event = KvCacheEvent { let event = KvCacheEvent {
event_id: 1, event_id: 1,
data: event_data, data: event_data,
dp_rank: 0,
}; };
let events = KvCacheEvents { let events = KvCacheEvents {
......
...@@ -326,13 +326,16 @@ pub async fn start_zmq_listener( ...@@ -326,13 +326,16 @@ pub async fn start_zmq_listener(
}; };
tracing::trace!( tracing::trace!(
"ZMQ listener on {} received batch with {} events (seq={})", "ZMQ listener on {} received batch with {} events (seq={}, dp_rank={})",
zmq_endpoint, zmq_endpoint,
batch.events.len(), batch.events.len(),
seq seq,
batch.data_parallel_rank
); );
let dp_rank = batch.data_parallel_rank;
for raw_event in batch.events.into_iter() { for raw_event in batch.events.into_iter() {
let event = convert_event(raw_event, seq, kv_block_size, &warning_count); let event = convert_event(raw_event, seq, kv_block_size, dp_rank, &warning_count);
if tx.send(event).is_err() { if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped"); tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped"; exit_reason = "channel receiver dropped";
...@@ -356,6 +359,7 @@ fn convert_event( ...@@ -356,6 +359,7 @@ fn convert_event(
raw: RawKvEvent, raw: RawKvEvent,
event_id: u64, event_id: u64,
kv_block_size: u32, kv_block_size: u32,
dp_rank: u32,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent { ) -> KvCacheEvent {
match raw { match raw {
...@@ -387,6 +391,7 @@ fn convert_event( ...@@ -387,6 +391,7 @@ fn convert_event(
warning_count, warning_count,
), ),
}), }),
dp_rank,
} }
} }
RawKvEvent::BlockRemoved { block_hashes, .. } => { RawKvEvent::BlockRemoved { block_hashes, .. } => {
...@@ -400,11 +405,13 @@ fn convert_event( ...@@ -400,11 +405,13 @@ fn convert_event(
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes, block_hashes: hashes,
}), }),
dp_rank,
} }
} }
RawKvEvent::AllBlocksCleared => KvCacheEvent { RawKvEvent::AllBlocksCleared => KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
dp_rank,
}, },
} }
} }
...@@ -1014,7 +1021,7 @@ mod test_event_processing { ...@@ -1014,7 +1021,7 @@ mod test_event_processing {
medium: None, medium: None,
}; };
let out = convert_event(raw_evt, 42, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Stored(_))); assert!(matches!(out.data, KvCacheEventData::Stored(_)));
} }
...@@ -1025,7 +1032,7 @@ mod test_event_processing { ...@@ -1025,7 +1032,7 @@ mod test_event_processing {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)], block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None, medium: None,
}; };
let out = convert_event(raw_evt, 7, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Removed(_))); assert!(matches!(out.data, KvCacheEventData::Removed(_)));
} }
...@@ -1034,7 +1041,7 @@ mod test_event_processing { ...@@ -1034,7 +1041,7 @@ mod test_event_processing {
fn test_convert_event_all_blocks_cleared() { fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4; let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared; let raw_evt = RawKvEvent::AllBlocksCleared;
let out = convert_event(raw_evt, 1, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared)); assert!(matches!(out.data, KvCacheEventData::Cleared));
} }
} }
...@@ -1115,6 +1122,7 @@ mod tests_startup_helpers { ...@@ -1115,6 +1122,7 @@ mod tests_startup_helpers {
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)], block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}), }),
dp_rank: 0,
}; };
let token = CancellationToken::new(); let token = CancellationToken::new();
......
...@@ -12,7 +12,6 @@ mod tests { ...@@ -12,7 +12,6 @@ mod tests {
use super::*; use super::*;
use crate::kv_router::indexer::KvIndexer; use crate::kv_router::indexer::KvIndexer;
use crate::kv_router::indexer::KvIndexerMetrics; use crate::kv_router::indexer::KvIndexerMetrics;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::*; use crate::kv_router::protocols::*;
use std::time::Duration; use std::time::Duration;
use tempfile::tempdir; use tempfile::tempdir;
...@@ -50,6 +49,7 @@ mod tests { ...@@ -50,6 +49,7 @@ mod tests {
KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: add_blocks(hashes, parent), data: add_blocks(hashes, parent),
dp_rank: 0,
}, },
) )
} }
...@@ -65,6 +65,7 @@ mod tests { ...@@ -65,6 +65,7 @@ mod tests {
.map(|i| ExternalSequenceBlockHash(*i * 100)) .map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(), .collect(),
}), }),
dp_rank: 0,
}, },
) )
} }
......
...@@ -18,21 +18,24 @@ use super::KvRouterConfig; ...@@ -18,21 +18,24 @@ use super::KvRouterConfig;
use super::RouterConfigOverride; use super::RouterConfigOverride;
use super::WorkerSelector; use super::WorkerSelector;
use super::indexer::OverlapScores; use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult; use super::protocols::{DpRank, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::sequence::ActiveSequencesMultiWorker; use super::sequence::ActiveSequencesMultiWorker;
use crate::tokens::SequenceHash; use crate::tokens::SequenceHash;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent { pub struct KVHitRateEvent {
pub worker_id: i64, pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
pub isl_blocks: usize, pub isl_blocks: usize,
pub overlap_blocks: u32, pub overlap_blocks: u32,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad { pub struct PotentialLoad {
pub worker_id: i64, pub worker_id: WorkerId,
pub dp_rank: DpRank,
pub potential_prefill_tokens: usize, pub potential_prefill_tokens: usize,
pub potential_decode_blocks: usize, pub potential_decode_blocks: usize,
} }
...@@ -51,7 +54,7 @@ pub enum KvSchedulerError { ...@@ -51,7 +54,7 @@ pub enum KvSchedulerError {
#[derive(Debug)] #[derive(Debug)]
pub struct SchedulingResponse { pub struct SchedulingResponse {
pub best_worker_id: i64, pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32, pub overlap_blocks: u32,
} }
...@@ -60,8 +63,8 @@ pub struct SchedulingRequest { ...@@ -60,8 +63,8 @@ pub struct SchedulingRequest {
pub token_seq: Option<Vec<SequenceHash>>, pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlaps: OverlapScores, pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>, pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: HashMap<i64, usize>, pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
// Router config overrides for this specific request // Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests) // Whether to update scheduler states (false for query_instance_id requests)
...@@ -94,17 +97,18 @@ impl KvScheduler { ...@@ -94,17 +97,18 @@ impl KvScheduler {
component: Component, component: Component,
block_size: u32, block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>, instances_rx: watch::Receiver<Vec<Instance>>,
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>, runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool, replica_sync: bool,
router_uuid: String, router_uuid: String,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone(); let instances: Vec<Instance> = instances_rx.borrow().clone();
let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone(); let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock> // Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = { let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new(); let mut initial_map = HashMap::new();
for instance in &instances { for instance in &instances {
let worker_id = instance.instance_id; let worker_id = instance.instance_id;
...@@ -117,14 +121,10 @@ impl KvScheduler { ...@@ -117,14 +121,10 @@ impl KvScheduler {
Arc::new(RwLock::new(initial_map)) Arc::new(RwLock::new(initial_map))
}; };
let worker_ids: Vec<i64> = instances
.iter()
.map(|instance| instance.instance_id)
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new( let slots = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size as usize, block_size as usize,
worker_ids, workers_with_configs.read().await.clone(), // this includes dp_size info
replica_sync, replica_sync,
router_uuid, router_uuid,
)); ));
...@@ -162,24 +162,23 @@ impl KvScheduler { ...@@ -162,24 +162,23 @@ impl KvScheduler {
let new_instances = instances_monitor_rx.borrow_and_update().clone(); let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone(); let new_configs = configs_monitor_rx.borrow_and_update().clone();
// Update workers when instances change // Build the new workers_with_configs map
let worker_ids: Vec<i64> = new_instances let mut new_workers_with_configs = HashMap::new();
.iter()
.map(|instance| instance.instance_id)
.collect();
slots_monitor.update_workers(worker_ids);
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
workers_map.clear();
for instance in &new_instances { for instance in &new_instances {
let worker_id = instance.instance_id; let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned(); let config = new_configs.get(&worker_id).cloned();
if config.is_some() { if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id); tracing::info!("Runtime config found for worker_id: {}", worker_id);
} }
workers_map.insert(worker_id, config); new_workers_with_configs.insert(worker_id, config);
} }
// Update workers when instances change
slots_monitor.update_workers(new_workers_with_configs.clone());
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
*workers_map = new_workers_with_configs;
tracing::trace!( tracing::trace!(
"Updated workers_with_configs with {} workers", "Updated workers_with_configs with {} workers",
workers_map.len() workers_map.len()
...@@ -229,7 +228,8 @@ impl KvScheduler { ...@@ -229,7 +228,8 @@ impl KvScheduler {
match selector.select_worker(&workers, &request, block_size) { match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => { Ok(selection) => {
let event = KVHitRateEvent { let event = KVHitRateEvent {
worker_id: selection.worker_id, worker_id: selection.worker.worker_id,
dp_rank: selection.worker.dp_rank,
isl_blocks: selection.required_blocks as usize, isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks, overlap_blocks: selection.overlap_blocks,
}; };
...@@ -238,7 +238,7 @@ impl KvScheduler { ...@@ -238,7 +238,7 @@ impl KvScheduler {
} }
let response = SchedulingResponse { let response = SchedulingResponse {
best_worker_id: selection.worker_id, best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks, overlap_blocks: selection.overlap_blocks,
}; };
request.respond(response); request.respond(response);
...@@ -261,7 +261,7 @@ impl KvScheduler { ...@@ -261,7 +261,7 @@ impl KvScheduler {
request.token_seq, request.token_seq,
request.isl_tokens, request.isl_tokens,
selection.overlap_blocks, selection.overlap_blocks,
selection.worker_id, selection.worker,
) )
.await .await
{ {
...@@ -302,7 +302,7 @@ impl KvScheduler { ...@@ -302,7 +302,7 @@ impl KvScheduler {
overlaps: OverlapScores, overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
) -> Result<i64, KvSchedulerError> { ) -> Result<WorkerWithDpRank, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
maybe_request_id, maybe_request_id,
...@@ -324,8 +324,7 @@ impl KvScheduler { ...@@ -324,8 +324,7 @@ impl KvScheduler {
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
let best_worker_id = response.best_worker_id; Ok(response.best_worker)
Ok(best_worker_id)
} }
pub async fn add_request( pub async fn add_request(
...@@ -334,11 +333,11 @@ impl KvScheduler { ...@@ -334,11 +333,11 @@ impl KvScheduler {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
worker_id: i64, worker: WorkerWithDpRank,
) { ) {
let _ = self let _ = self
.slots .slots
.add_request(request_id, token_sequence, isl, overlap, worker_id) .add_request(request_id, token_sequence, isl, overlap, worker)
.await; .await;
} }
...@@ -363,21 +362,22 @@ impl KvScheduler { ...@@ -363,21 +362,22 @@ impl KvScheduler {
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps) .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
.await; .await;
// Get all unique worker IDs from both hashmaps // Get all unique WorkerWithDpRank from both hashmaps
let mut worker_ids: HashSet<i64> = HashSet::new(); let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
worker_ids.extend(decode_blocks.keys().copied()); workers.extend(decode_blocks.keys().copied());
worker_ids.extend(prefill_tokens.keys().copied()); workers.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker // Create PotentialLoad for each worker
let mut loads = Vec::new(); let mut loads = Vec::new();
for worker_id in worker_ids { for worker in workers {
loads.push(PotentialLoad { loads.push(PotentialLoad {
worker_id, worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens potential_prefill_tokens: prefill_tokens
.get(&worker_id) .get(&worker)
.copied() .copied()
.unwrap_or(isl_tokens), .unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0), potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
}); });
} }
...@@ -386,7 +386,7 @@ impl KvScheduler { ...@@ -386,7 +386,7 @@ impl KvScheduler {
} }
// Helper function for softmax sampling // Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 { fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> WorkerWithDpRank {
if logits.is_empty() { if logits.is_empty() {
panic!("Empty logits for softmax sampling"); panic!("Empty logits for softmax sampling");
} }
...@@ -474,7 +474,7 @@ impl DefaultWorkerSelector { ...@@ -474,7 +474,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector { impl WorkerSelector for DefaultWorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>, workers: &HashMap<WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
...@@ -494,38 +494,52 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -494,38 +494,52 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut worker_logits = HashMap::new(); let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY; let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker // Calculate logits for each worker with dp_rank
for worker_id in workers.keys() { // Outer loop: iterate over all workers from runtime config
let overlap = *overlaps.get(worker_id).unwrap_or(&0); // Inner loop: iterate over all dp_ranks for each worker
for (worker_id, config) in workers.iter() {
// this is the number of prefill tokens the worker would have if the request were scheduled there // Get data_parallel_size from runtime config
let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl); // data_parallel_size defaults to 1 in ModelRuntimeConfig
let potential_prefill_block = (prefill_token as f64) / (block_size as f64); let data_parallel_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); // Fallback if config is None
// this is the number of decode blocks the worker would have if the request were scheduled there // Iterate over all dp_ranks for this worker
let decode_block = *decode_blocks for dp_rank in 0..data_parallel_size {
.get(worker_id) let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64; // Get overlap for this worker (defaults to 0 if not in overlaps)
let overlap = *overlaps.get(&worker).unwrap_or(&0);
// Use override if provided, otherwise use default config
let overlap_weight = request // this is the number of prefill tokens the worker would have if the request were scheduled there
.router_config_override let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
.as_ref() let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight); // this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks
// Calculate logit (lower is better) .get(&worker)
let logit = overlap_weight * potential_prefill_block + decode_block; .unwrap_or(&(potential_prefill_block.floor() as usize))
max_logit = max_logit.max(logit); as f64;
worker_logits.insert(*worker_id, logit); // Use override if provided, otherwise use default config
let overlap_weight = request
tracing::info!( .router_config_override
"Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \ .as_ref()
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ .and_then(|cfg| cfg.overlap_score_weight)
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}" .unwrap_or(self.kv_router_config.overlap_score_weight);
);
// Calculate logit (lower is better)
let logit = overlap_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit);
worker_logits.insert(worker, logit);
tracing::info!(
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
);
}
} }
// Use softmax sampling to select worker // Use softmax sampling to select worker
...@@ -535,29 +549,32 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -535,29 +549,32 @@ impl WorkerSelector for DefaultWorkerSelector {
.as_ref() .as_ref()
.and_then(|cfg| cfg.router_temperature) .and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let best_worker_id = softmax_sample(&worker_logits, temperature); let best_worker = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker_id]; let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0); // this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers let total_blocks_info = workers
.get(&best_worker_id) .get(&best_worker.worker_id)
.and_then(|cfg| cfg.as_ref()) .and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks) .and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks)) .map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default(); .unwrap_or_default();
tracing::info!( tracing::info!(
"Selected worker: {}, logit: {:.3}, cached blocks: {}{}", "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}",
best_worker_id, best_worker.worker_id,
best_worker.dp_rank,
best_logit, best_logit,
best_overlap, best_overlap,
total_blocks_info total_blocks_info
); );
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker_id: best_worker_id, worker: best_worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0), overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
}) })
} }
} }
...@@ -570,54 +587,61 @@ mod tests { ...@@ -570,54 +587,61 @@ mod tests {
fn test_softmax_sample_single_key() { fn test_softmax_sample_single_key() {
// Test that with a single key, softmax_sample always returns that key // Test that with a single key, softmax_sample always returns that key
let mut logits = HashMap::new(); let mut logits = HashMap::new();
let worker_id = 42; let worker = WorkerWithDpRank::from_worker_id(42);
logits.insert(worker_id, 0.5); // The value doesn't matter logits.insert(worker, 0.5); // The value doesn't matter
// Test with different temperatures // Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] { for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature); let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker_id, "Should return the only available worker"); assert_eq!(result, worker, "Should return the only available worker");
} }
// Test with different logit values // Test with different logit values
logits.clear(); logits.clear();
logits.insert(worker_id, -100.0); // Very negative value logits.insert(worker, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear(); logits.clear();
logits.insert(worker_id, 100.0); // Very positive value logits.insert(worker, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear(); logits.clear();
logits.insert(worker_id, 0.0); // Zero value logits.insert(worker, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker_id); assert_eq!(softmax_sample(&logits, 1.0), worker);
} }
#[test] #[test]
fn test_softmax_sample_zero_temperature() { fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit // Test that with temperature 0, softmax_sample returns the key with smallest logit
let mut logits = HashMap::new(); let mut logits = HashMap::new();
logits.insert(1, 5.0); let worker1 = WorkerWithDpRank::from_worker_id(1);
logits.insert(2, 3.0); // This has the smallest logit let worker2 = WorkerWithDpRank::from_worker_id(2);
logits.insert(3, 7.0); let worker3 = WorkerWithDpRank::from_worker_id(3);
logits.insert(4, 3.5); let worker4 = WorkerWithDpRank::from_worker_id(4);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // This has the smallest logit
logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit) // With temperature 0, should always return worker 2 (smallest logit)
for _ in 0..10 { for _ in 0..10 {
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!( assert_eq!(
result, 2, result, worker2,
"Should return worker with smallest logit when temperature is 0" "Should return worker with smallest logit when temperature is 0"
); );
} }
// Test with negative values // Test with negative values
logits.clear(); logits.clear();
logits.insert(10, -1.0); let worker10 = WorkerWithDpRank::from_worker_id(10);
logits.insert(20, -5.0); // This has the smallest logit let worker20 = WorkerWithDpRank::from_worker_id(20);
logits.insert(30, 0.0); let worker30 = WorkerWithDpRank::from_worker_id(30);
logits.insert(worker10, -1.0);
logits.insert(worker20, -5.0); // This has the smallest logit
logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0); let result = softmax_sample(&logits, 0.0);
assert_eq!(result, 20, "Should handle negative logits correctly"); assert_eq!(result, worker20, "Should handle negative logits correctly");
} }
} }
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
//! requests share common prefixes (e.g., system prompts, few-shot examples). //! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId;
use crate::tokens::SequenceHash; use crate::tokens::SequenceHash;
use anyhow::Result; use anyhow::Result;
use dashmap::DashMap; use dashmap::DashMap;
...@@ -39,8 +38,9 @@ use std::time::Duration; ...@@ -39,8 +38,9 @@ use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData}; use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank};
use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT; use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::CancellationToken; use dynamo_runtime::CancellationToken;
/// Duration after which stale requests are forcibly expired (5 minutes) /// Duration after which stale requests are forcibly expired (5 minutes)
...@@ -280,9 +280,9 @@ enum UpdateSequences { ...@@ -280,9 +280,9 @@ enum UpdateSequences {
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads /// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
pub struct ActiveSequencesMultiWorker { pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>, senders: Arc<DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>, request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
handles: Arc<DashMap<WorkerId, std::thread::JoinHandle<()>>>, handles: Arc<DashMap<WorkerWithDpRank, std::thread::JoinHandle<()>>>,
block_size: usize, block_size: usize,
component: Component, component: Component,
router_id: Uuid, router_id: Uuid,
...@@ -293,7 +293,7 @@ impl ActiveSequencesMultiWorker { ...@@ -293,7 +293,7 @@ impl ActiveSequencesMultiWorker {
pub fn new( pub fn new(
component: Component, component: Component,
block_size: usize, block_size: usize,
worker_ids: Vec<WorkerId>, workers_with_configs: HashMap<i64, Option<ModelRuntimeConfig>>,
replica_sync: bool, replica_sync: bool,
router_uuid: String, router_uuid: String,
) -> Self { ) -> Self {
...@@ -311,12 +311,18 @@ impl ActiveSequencesMultiWorker { ...@@ -311,12 +311,18 @@ impl ActiveSequencesMultiWorker {
Uuid::new_v4() Uuid::new_v4()
}); });
for worker_id in worker_ids { // Expand workers by their dp_rank
// Create a child cancellation token from the component's runtime for (worker_id, config) in workers_with_configs {
let cancel_token = component.drt().runtime().child_token(); let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1);
let (sender, handle) = Self::start_worker(block_size, cancel_token);
senders.insert(worker_id, sender); for dp_rank in 0..dp_size {
handles.insert(worker_id, handle); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Create a child cancellation token from the component's runtime
let cancel_token = component.drt().runtime().child_token();
let (sender, handle) = Self::start_worker(block_size, cancel_token);
senders.insert(worker, sender);
handles.insert(worker, handle);
}
} }
let multi_worker = Self { let multi_worker = Self {
...@@ -458,8 +464,10 @@ impl ActiveSequencesMultiWorker { ...@@ -458,8 +464,10 @@ impl ActiveSequencesMultiWorker {
/// Background task to subscribe to active sequence events and update all workers /// Background task to subscribe to active sequence events and update all workers
async fn subscribe_to_events( async fn subscribe_to_events(
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>, senders: Arc<
request_to_worker: Arc<DashMap<RequestId, WorkerId>>, DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>,
>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
component: Component, component: Component,
router_id: Uuid, router_id: Uuid,
cancel_token: CancellationToken, cancel_token: CancellationToken,
...@@ -496,9 +504,9 @@ impl ActiveSequencesMultiWorker { ...@@ -496,9 +504,9 @@ impl ActiveSequencesMultiWorker {
isl, isl,
overlap, overlap,
} => { } => {
request_to_worker.insert(event.request_id.clone(), event.worker_id); request_to_worker.insert(event.request_id.clone(), event.worker);
if let Some(sender) = senders.get(&event.worker_id) { if let Some(sender) = senders.get(&event.worker) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests // For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel(); let (resp_tx, _) = tokio::sync::oneshot::channel();
let _ = sender.send(UpdateSequences::AddRequest { let _ = sender.send(UpdateSequences::AddRequest {
...@@ -510,14 +518,14 @@ impl ActiveSequencesMultiWorker { ...@@ -510,14 +518,14 @@ impl ActiveSequencesMultiWorker {
}); });
} else { } else {
tracing::warn!( tracing::warn!(
"Worker {} not found, cannot process AddRequest", "Worker {:?} not found, cannot process AddRequest",
event.worker_id event.worker
); );
} }
} }
ActiveSequenceEventData::Free => { ActiveSequenceEventData::Free => {
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id) if let Some((_, worker)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker_id) && let Some(sender) = senders.get(&worker)
{ {
let _ = sender.send(UpdateSequences::Free { let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(), request_id: event.request_id.clone(),
...@@ -525,8 +533,8 @@ impl ActiveSequencesMultiWorker { ...@@ -525,8 +533,8 @@ impl ActiveSequencesMultiWorker {
} }
} }
ActiveSequenceEventData::MarkPrefillCompleted => { ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker_id) = request_to_worker.get(&event.request_id) if let Some(worker) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker_id) && let Some(sender) = senders.get(&*worker)
{ {
let _ = sender.send(UpdateSequences::MarkPrefillCompleted { let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(), request_id: event.request_id.clone(),
...@@ -547,41 +555,53 @@ impl ActiveSequencesMultiWorker { ...@@ -547,41 +555,53 @@ impl ActiveSequencesMultiWorker {
} }
/// Update the set of workers, adding and removing as needed /// Update the set of workers, adding and removing as needed
pub fn update_workers(&self, new_worker_ids: Vec<WorkerId>) { pub fn update_workers(
let current_workers: HashSet<WorkerId> = &self,
new_workers_with_configs: HashMap<i64, Option<ModelRuntimeConfig>>,
) {
let current_workers: HashSet<WorkerWithDpRank> =
self.senders.iter().map(|entry| *entry.key()).collect(); self.senders.iter().map(|entry| *entry.key()).collect();
let new_workers: HashSet<WorkerId> = new_worker_ids.into_iter().collect();
let workers_to_remove: Vec<WorkerId> = // Expand new workers by their dp_rank
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1);
for dp_rank in 0..dp_size {
new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
}
}
let workers_to_remove: Vec<WorkerWithDpRank> =
current_workers.difference(&new_workers).copied().collect(); current_workers.difference(&new_workers).copied().collect();
let workers_to_add: Vec<WorkerId> = let workers_to_add: Vec<WorkerWithDpRank> =
new_workers.difference(&current_workers).copied().collect(); new_workers.difference(&current_workers).copied().collect();
// Remove workers // Remove workers (this will naturally remove all dp ranks for a worker_id)
for worker_id in &workers_to_remove { for worker in &workers_to_remove {
tracing::warn!("Removing worker {}", worker_id); tracing::warn!("Removing worker {:?}", worker);
// Send shutdown command to the worker // Send shutdown command to the worker
if let Some((_, sender)) = self.senders.remove(worker_id) { if let Some((_, sender)) = self.senders.remove(worker) {
let _ = sender.send(UpdateSequences::Shutdown); let _ = sender.send(UpdateSequences::Shutdown);
} }
self.handles.remove(worker_id); self.handles.remove(worker);
// Clean up request_to_worker mappings for this worker // Clean up request_to_worker mappings for this worker
self.request_to_worker self.request_to_worker
.retain(|_request_id, mapped_worker_id| *mapped_worker_id != *worker_id); .retain(|_request_id, mapped_worker| mapped_worker != worker);
} }
// Add new workers // Add new workers
for worker_id in &workers_to_add { for worker in &workers_to_add {
tracing::warn!("Adding worker {}", worker_id); tracing::warn!("Adding worker {:?}", worker);
let (sender, handle) = Self::start_worker( let (sender, handle) = Self::start_worker(
self.block_size, self.block_size,
self.component.drt().runtime().child_token(), self.component.drt().runtime().child_token(),
); );
self.senders.insert(*worker_id, sender); self.senders.insert(*worker, sender);
self.handles.insert(*worker_id, handle); self.handles.insert(*worker, handle);
} }
} }
...@@ -591,10 +611,10 @@ impl ActiveSequencesMultiWorker { ...@@ -591,10 +611,10 @@ impl ActiveSequencesMultiWorker {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
worker_id: WorkerId, worker: WorkerWithDpRank,
) -> Result<()> { ) -> Result<()> {
if !self.senders.contains_key(&worker_id) { if !self.senders.contains_key(&worker) {
return Err(anyhow::anyhow!("Worker ID {worker_id} not found")); return Err(anyhow::anyhow!("Worker {:?} not found", worker));
} }
// Create response channel // Create response channel
...@@ -604,7 +624,7 @@ impl ActiveSequencesMultiWorker { ...@@ -604,7 +624,7 @@ impl ActiveSequencesMultiWorker {
if self.replica_sync { if self.replica_sync {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
worker_id, worker,
data: ActiveSequenceEventData::AddRequest { data: ActiveSequenceEventData::AddRequest {
token_sequence: token_sequence.clone(), token_sequence: token_sequence.clone(),
isl, isl,
...@@ -617,11 +637,11 @@ impl ActiveSequencesMultiWorker { ...@@ -617,11 +637,11 @@ impl ActiveSequencesMultiWorker {
.await?; .await?;
} }
// Update local state // Update local state with full WorkerWithDpRank
self.request_to_worker.insert(request_id.clone(), worker_id); self.request_to_worker.insert(request_id.clone(), worker);
self.senders self.senders
.get(&worker_id) .get(&worker)
.unwrap() .unwrap()
.send(UpdateSequences::AddRequest { .send(UpdateSequences::AddRequest {
request_id, request_id,
...@@ -646,7 +666,7 @@ impl ActiveSequencesMultiWorker { ...@@ -646,7 +666,7 @@ impl ActiveSequencesMultiWorker {
} }
pub async fn free(&self, request_id: &RequestId) -> Result<()> { pub async fn free(&self, request_id: &RequestId) -> Result<()> {
let worker_id = self let worker = self
.request_to_worker .request_to_worker
.get(request_id) .get(request_id)
.map(|entry| *entry) .map(|entry| *entry)
...@@ -656,7 +676,7 @@ impl ActiveSequencesMultiWorker { ...@@ -656,7 +676,7 @@ impl ActiveSequencesMultiWorker {
if self.replica_sync { if self.replica_sync {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
worker_id, worker,
data: ActiveSequenceEventData::Free, data: ActiveSequenceEventData::Free,
router_id: self.router_id, router_id: self.router_id,
}; };
...@@ -667,7 +687,7 @@ impl ActiveSequencesMultiWorker { ...@@ -667,7 +687,7 @@ impl ActiveSequencesMultiWorker {
// Update local state // Update local state
self.senders self.senders
.get(&worker_id) .get(&worker)
.unwrap() .unwrap()
.send(UpdateSequences::Free { .send(UpdateSequences::Free {
request_id: request_id.clone(), request_id: request_id.clone(),
...@@ -681,7 +701,7 @@ impl ActiveSequencesMultiWorker { ...@@ -681,7 +701,7 @@ impl ActiveSequencesMultiWorker {
/// Mark prefill as completed for a request /// Mark prefill as completed for a request
pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> { pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> {
let worker_id = self let worker = self
.request_to_worker .request_to_worker
.get(request_id) .get(request_id)
.map(|entry| *entry) .map(|entry| *entry)
...@@ -691,7 +711,7 @@ impl ActiveSequencesMultiWorker { ...@@ -691,7 +711,7 @@ impl ActiveSequencesMultiWorker {
if self.replica_sync { if self.replica_sync {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
worker_id, worker,
data: ActiveSequenceEventData::MarkPrefillCompleted, data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id, router_id: self.router_id,
}; };
...@@ -702,7 +722,7 @@ impl ActiveSequencesMultiWorker { ...@@ -702,7 +722,7 @@ impl ActiveSequencesMultiWorker {
// Update local state // Update local state
self.senders self.senders
.get(&worker_id) .get(&worker)
.unwrap() .unwrap()
.send(UpdateSequences::MarkPrefillCompleted { .send(UpdateSequences::MarkPrefillCompleted {
request_id: request_id.clone(), request_id: request_id.clone(),
...@@ -727,33 +747,33 @@ impl ActiveSequencesMultiWorker { ...@@ -727,33 +747,33 @@ impl ActiveSequencesMultiWorker {
Option<Arc<Vec<SequenceHash>>>, Option<Arc<Vec<SequenceHash>>>,
tokio::sync::oneshot::Sender<T>, tokio::sync::oneshot::Sender<T>,
) -> UpdateSequences, ) -> UpdateSequences,
) -> HashMap<WorkerId, T> { ) -> HashMap<WorkerWithDpRank, T> {
let mut results = HashMap::new(); let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new); let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new(); let mut receivers = Vec::new();
// Send queries to all workers in parallel // Send queries to all workers in parallel
for entry in self.senders.iter() { for entry in self.senders.iter() {
let worker_id = *entry.key(); let worker = *entry.key();
let sender = entry.value(); let sender = entry.value();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker_id, resp_rx)); receivers.push((worker, resp_rx));
if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) { if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) {
tracing::error!("Failed to send command to worker {}: {}", worker_id, e); tracing::error!("Failed to send command to worker {:?}: {}", worker, e);
} }
} }
// Collect results from all workers // Collect results from all workers
for (worker_id, receiver) in receivers { for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok(result)) => { Ok(Ok(result)) => {
results.insert(worker_id, result); results.insert(worker, result);
} }
Ok(Err(_)) => { Ok(Err(_)) => {
tracing::error!("Worker {} dropped response channel", worker_id); tracing::error!("Worker {:?} dropped response channel", worker);
} }
Err(_) => { Err(_) => {
tracing::error!("Timeout waiting for response from worker {}", worker_id); tracing::error!("Timeout waiting for response from worker {:?}", worker);
} }
} }
} }
...@@ -762,7 +782,10 @@ impl ActiveSequencesMultiWorker { ...@@ -762,7 +782,10 @@ impl ActiveSequencesMultiWorker {
} }
/// Query all workers for the number of new blocks that would be added by a token sequence /// Query all workers for the number of new blocks that would be added by a token sequence
pub async fn new_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> { pub async fn new_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::NewBlocks { Some(ts) => UpdateSequences::NewBlocks {
token_sequence: ts, token_sequence: ts,
...@@ -777,7 +800,7 @@ impl ActiveSequencesMultiWorker { ...@@ -777,7 +800,7 @@ impl ActiveSequencesMultiWorker {
pub async fn potential_blocks( pub async fn potential_blocks(
&self, &self,
token_sequence: Vec<SequenceHash>, token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerId, usize> { ) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::PotentialBlocks { Some(ts) => UpdateSequences::PotentialBlocks {
token_sequence: ts, token_sequence: ts,
...@@ -794,45 +817,49 @@ impl ActiveSequencesMultiWorker { ...@@ -794,45 +817,49 @@ impl ActiveSequencesMultiWorker {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) { ) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
let mut potential_blocks = HashMap::new(); let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new(); let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new); let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new(); let mut receivers = Vec::new();
// Send queries to all workers in parallel // Iterate through overlaps to process each WorkerWithDpRank
for entry in self.senders.iter() { for (worker, overlap) in overlaps.scores.iter() {
let worker_id = *entry.key(); // Check if the worker has a sender
let sender = entry.value(); if let Some(sender) = self.senders.get(worker) {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker_id, resp_rx)); receivers.push((*worker, resp_rx));
if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens { if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens {
token_sequence: token_sequence_shared.clone(), token_sequence: token_sequence_shared.clone(),
isl, isl,
overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0), overlap: *overlap,
resp_tx, resp_tx,
}) { }) {
tracing::error!( tracing::error!(
"Failed to send potential_tokens command to worker {}: {}", "Failed to send potential_tokens command to worker {:?}: {}",
worker_id, worker,
e e
); );
}
} }
} }
// Collect results from all workers // Collect results from all workers
for (worker_id, receiver) in receivers { for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok((blocks, tokens))) => { Ok(Ok((blocks, tokens))) => {
potential_blocks.insert(worker_id, blocks); potential_blocks.insert(worker, blocks);
potential_tokens.insert(worker_id, tokens); potential_tokens.insert(worker, tokens);
} }
Ok(Err(_)) => { Ok(Err(_)) => {
tracing::error!("Worker {} dropped response channel", worker_id); tracing::error!("Worker {:?} dropped response channel", worker);
} }
Err(_) => { Err(_) => {
tracing::error!("Timeout waiting for response from worker {}", worker_id); tracing::error!("Timeout waiting for response from worker {:?}", worker);
} }
} }
} }
...@@ -841,13 +868,13 @@ impl ActiveSequencesMultiWorker { ...@@ -841,13 +868,13 @@ impl ActiveSequencesMultiWorker {
} }
/// Query all workers for their current number of active blocks /// Query all workers for their current number of active blocks
pub async fn active_blocks(&self) -> HashMap<WorkerId, usize> { pub async fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
.await .await
} }
/// Query all workers for their current number of active tokens /// Query all workers for their current number of active tokens
pub async fn active_tokens(&self) -> HashMap<WorkerId, usize> { pub async fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx }) self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
.await .await
} }
...@@ -918,20 +945,33 @@ mod tests { ...@@ -918,20 +945,33 @@ mod tests {
.create() .create()
.await?; .await?;
// Create multi-worker sequence managers with ALL workers [0, 1, 2] // Create multi-worker sequence managers with:
// Both use the same component to ensure event synchronization works // - Worker 0 with dp_size=2 (dp_ranks 0 and 1)
let worker_ids = vec![0, 1, 2]; // - Worker 1 with dp_size=1 (dp_rank 0)
// This gives us 3 effective workers total to test dp_rank effect
// Both seq_managers use the same component to ensure event synchronization works
let mut workers_with_configs = HashMap::new();
// Create runtime config for worker 0 with dp_size=2
let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
config_worker_0.data_parallel_size = 2;
workers_with_configs.insert(0, Some(config_worker_0));
// Create runtime config for worker 1 with dp_size=1 (default)
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, Some(config_worker_1));
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new( let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size, block_size,
worker_ids.clone(), workers_with_configs.clone(),
true, true,
Uuid::new_v4().to_string(), Uuid::new_v4().to_string(),
)); ));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new( let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component, component,
block_size, block_size,
worker_ids, workers_with_configs,
true, true,
Uuid::new_v4().to_string(), Uuid::new_v4().to_string(),
)); ));
...@@ -941,36 +981,36 @@ mod tests { ...@@ -941,36 +981,36 @@ mod tests {
// PHASE 1: Add requests using both seq_manager_1 and seq_manager_2 // PHASE 1: Add requests using both seq_manager_1 and seq_manager_2
// Add request_0 to worker 0: sequence [0, 1, 2] // Add request_0 to worker 0, dp_rank 0: sequence [0, 1, 2]
seq_manager_1 seq_manager_1
.add_request( .add_request(
"request_0".to_string(), "request_0".to_string(),
Some(vec![0, 1, 2]), Some(vec![0, 1, 2]),
12, // ISL (3 blocks * 4 block_size) 12, // ISL (3 blocks * 4 block_size)
0, // no overlap 0, // no overlap
0, // worker_id WorkerWithDpRank::new(0, 0),
) )
.await?; .await?;
// Add request_1 to worker 1: sequence [3, 4] // Add request_1 to worker 0, dp_rank 1: sequence [3, 4]
seq_manager_1 seq_manager_1
.add_request( .add_request(
"request_1".to_string(), "request_1".to_string(),
Some(vec![3, 4]), Some(vec![3, 4]),
8, // ISL (2 blocks * 4 block_size) 8, // ISL (2 blocks * 4 block_size)
0, // no overlap 0, // no overlap
1, // worker_id WorkerWithDpRank::new(0, 1),
) )
.await?; .await?;
// Add request_2 to worker 2: sequence [0, 1, 2, 3] using seq_manager_2 // Add request_2 to worker 1, dp_rank 0: sequence [0, 1, 2, 3] using seq_manager_2
seq_manager_2 seq_manager_2
.add_request( .add_request(
"request_2".to_string(), "request_2".to_string(),
Some(vec![0, 1, 2, 3]), Some(vec![0, 1, 2, 3]),
16, // ISL (4 blocks * 4 block_size) 16, // ISL (4 blocks * 4 block_size)
0, // no overlap 0, // no overlap
2, // worker_id WorkerWithDpRank::new(1, 0),
) )
.await?; .await?;
...@@ -981,27 +1021,38 @@ mod tests { ...@@ -981,27 +1021,38 @@ mod tests {
let blocks_phase1 = seq_manager_1.active_blocks().await; let blocks_phase1 = seq_manager_1.active_blocks().await;
let tokens_phase1 = seq_manager_1.active_tokens().await; let tokens_phase1 = seq_manager_1.active_tokens().await;
// Verify that seq_manager_1 sees all requests including request_2 from thread 2 // Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// We now have:
// - Worker 0, dp_rank 0: request_0
// - Worker 0, dp_rank 1: request_1
// - Worker 1, dp_rank 0: request_2
let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
let worker_1_dp0 = WorkerWithDpRank::new(1, 0);
assert_eq!( assert_eq!(
blocks_phase1[&0], 3, blocks_phase1[&worker_0_dp0], 3,
"Worker 0 should have 3 active blocks (from request_0)" "Worker 0 dp_rank 0 should have 3 active blocks (from request_0)"
); );
assert_eq!( assert_eq!(
blocks_phase1[&1], 2, blocks_phase1[&worker_0_dp1], 2,
"Worker 1 should have 2 active blocks (from request_1)" "Worker 0 dp_rank 1 should have 2 active blocks (from request_1)"
); );
assert_eq!( assert_eq!(
blocks_phase1[&2], 4, blocks_phase1[&worker_1_dp0], 4,
"Worker 2 should have 4 active blocks (from request_2 added by seq_manager_2)" "Worker 1 dp_rank 0 should have 4 active blocks (from request_2 added by seq_manager_2)"
); );
assert_eq!( assert_eq!(
tokens_phase1[&0], 12, tokens_phase1[&worker_0_dp0], 12,
"Worker 0 should have 12 active tokens" "Worker 0 dp_rank 0 should have 12 active tokens"
); );
assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
assert_eq!( assert_eq!(
tokens_phase1[&2], 16, tokens_phase1[&worker_0_dp1], 8,
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 0 dp_rank 1 should have 8 active tokens"
);
assert_eq!(
tokens_phase1[&worker_1_dp0], 16,
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
// PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2 // PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2
...@@ -1020,17 +1071,23 @@ mod tests { ...@@ -1020,17 +1071,23 @@ mod tests {
let blocks_phase2 = seq_manager_2.active_blocks().await; let blocks_phase2 = seq_manager_2.active_blocks().await;
let tokens_phase2 = seq_manager_2.active_tokens().await; let tokens_phase2 = seq_manager_2.active_tokens().await;
// Verify phase 2 results - everything should be empty // Verify phase 2 results - everything should be empty for all 3 workers
for worker_id in 0..=2 { let all_workers = vec![
WorkerWithDpRank::new(0, 0),
WorkerWithDpRank::new(0, 1),
WorkerWithDpRank::new(1, 0),
];
for worker in all_workers {
assert_eq!( assert_eq!(
blocks_phase2[&worker_id], 0, blocks_phase2[&worker], 0,
"Worker {} should have 0 active blocks after all requests freed", "Worker (id={}, dp_rank={}) should have 0 active blocks after all requests freed",
worker_id worker.worker_id, worker.dp_rank
); );
assert_eq!( assert_eq!(
tokens_phase2[&worker_id], 0, tokens_phase2[&worker], 0,
"Worker {} should have 0 active tokens after all requests freed", "Worker (id={}, dp_rank={}) should have 0 active tokens after all requests freed",
worker_id worker.worker_id, worker.dp_rank
); );
} }
...@@ -1059,18 +1116,22 @@ mod tests { ...@@ -1059,18 +1116,22 @@ mod tests {
// Create multi-worker sequence managers with ALL workers [0, 1, 2] // Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works // Both use the same component to ensure event synchronization works
let worker_ids = vec![0, 1, 2]; let mut workers_with_configs = HashMap::new();
workers_with_configs.insert(0, None);
workers_with_configs.insert(1, None);
workers_with_configs.insert(2, None);
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new( let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size, block_size,
worker_ids.clone(), workers_with_configs.clone(),
true, true,
Uuid::new_v4().to_string(), Uuid::new_v4().to_string(),
)); ));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new( let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component, component,
block_size, block_size,
worker_ids, workers_with_configs,
true, true,
Uuid::new_v4().to_string(), Uuid::new_v4().to_string(),
)); ));
...@@ -1087,7 +1148,7 @@ mod tests { ...@@ -1087,7 +1148,7 @@ mod tests {
None, // No token sequence None, // No token sequence
12, // ISL (12 tokens) 12, // ISL (12 tokens)
0, // no overlap 0, // no overlap
0, // worker_id WorkerWithDpRank::from_worker_id(0),
) )
.await?; .await?;
...@@ -1098,7 +1159,7 @@ mod tests { ...@@ -1098,7 +1159,7 @@ mod tests {
None, // No token sequence None, // No token sequence
8, // ISL (8 tokens) 8, // ISL (8 tokens)
0, // no overlap 0, // no overlap
1, // worker_id WorkerWithDpRank::from_worker_id(1),
) )
.await?; .await?;
...@@ -1109,7 +1170,7 @@ mod tests { ...@@ -1109,7 +1170,7 @@ mod tests {
None, // No token sequence None, // No token sequence
16, // ISL (16 tokens) 16, // ISL (16 tokens)
0, // no overlap 0, // no overlap
2, // worker_id WorkerWithDpRank::from_worker_id(2),
) )
.await?; .await?;
...@@ -1120,13 +1181,20 @@ mod tests { ...@@ -1120,13 +1181,20 @@ mod tests {
let tokens_phase1 = seq_manager_1.active_tokens().await; let tokens_phase1 = seq_manager_1.active_tokens().await;
// Verify that seq_manager_1 sees all requests including request_2 from thread 2 // Verify that seq_manager_1 sees all requests including request_2 from thread 2
let worker_0 = WorkerWithDpRank::from_worker_id(0);
let worker_1 = WorkerWithDpRank::from_worker_id(1);
let worker_2 = WorkerWithDpRank::from_worker_id(2);
assert_eq!( assert_eq!(
tokens_phase1[&0], 12, tokens_phase1[&worker_0], 12,
"Worker 0 should have 12 active tokens" "Worker 0 should have 12 active tokens"
); );
assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
assert_eq!( assert_eq!(
tokens_phase1[&2], 16, tokens_phase1[&worker_1], 8,
"Worker 1 should have 8 active tokens"
);
assert_eq!(
tokens_phase1[&worker_2], 16,
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
...@@ -1156,8 +1224,9 @@ mod tests { ...@@ -1156,8 +1224,9 @@ mod tests {
// Verify phase 2 results - everything should be empty // Verify phase 2 results - everything should be empty
for worker_id in 0..=2 { for worker_id in 0..=2 {
let worker = WorkerWithDpRank::from_worker_id(worker_id);
assert_eq!( assert_eq!(
tokens_phase2[&worker_id], 0, tokens_phase2[&worker], 0,
"Worker {} should have 0 active tokens after all requests freed", "Worker {} should have 0 active tokens after all requests freed",
worker_id worker_id
); );
......
...@@ -23,7 +23,8 @@ use crate::{ ...@@ -23,7 +23,8 @@ use crate::{
kv_router::{ kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK, ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerId}, indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
}, },
}; };
......
...@@ -211,11 +211,26 @@ impl LocalModelBuilder { ...@@ -211,11 +211,26 @@ impl LocalModelBuilder {
.map(RequestTemplate::load) .map(RequestTemplate::load)
.transpose()?; .transpose()?;
// Override runtime configs with mocker engine args (applies to both paths)
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.kv_cache_block_size = mocker_engine_args.block_size as u32;
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
}
// frontend and echo engine don't need a path. // frontend and echo engine don't need a path.
if self.model_path.is_none() { if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only( let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME), self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
); );
card.kv_cache_block_size = self.kv_cache_block_size;
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone(); card.runtime_config = self.runtime_config.clone();
...@@ -266,18 +281,6 @@ impl LocalModelBuilder { ...@@ -266,18 +281,6 @@ impl LocalModelBuilder {
card.context_length = context_length; card.context_length = context_length;
} }
// Override runtime configs with mocker engine args
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take(); card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone(); card.runtime_config = self.runtime_config.clone();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment