Unverified Commit f978f4d1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: dp rank routing (#3597)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 29f5b822
......@@ -7,6 +7,7 @@
NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1
DATA_PARALLEL_SIZE=1
USE_MOCKERS=false
USE_TRTLLM=false
MODE="agg" # Options: agg (default), decode, prefill
......@@ -28,6 +29,10 @@ while [[ $# -gt 0 ]]; do
TENSOR_PARALLEL_SIZE="$2"
shift 2
;;
--data-parallel-size)
DATA_PARALLEL_SIZE="$2"
shift 2
;;
--mockers)
USE_MOCKERS=true
shift
......@@ -114,13 +119,19 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt
exit 1
fi
if ! [[ "$DATA_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$DATA_PARALLEL_SIZE" -lt 1 ]; then
echo "Error: DATA_PARALLEL_SIZE must be a positive integer"
exit 1
fi
if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then
echo "Error: BASE_GPU_OFFSET must be a non-negative integer"
exit 1
fi
# Calculate total GPUs needed
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
# Calculate total GPUs needed (TP * DP per worker)
GPUS_PER_WORKER=$((TENSOR_PARALLEL_SIZE * DATA_PARALLEL_SIZE))
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * GPUS_PER_WORKER))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:"
if [ "$USE_MOCKERS" = true ]; then
......@@ -135,6 +146,8 @@ echo " Mode: $MODE"
echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Data Parallel Size: $DATA_PARALLEL_SIZE"
echo " GPUs per worker: $GPUS_PER_WORKER"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU"
echo " Engine args: ${EXTRA_ARGS[*]}"
......@@ -155,14 +168,16 @@ echo "Starting $NUM_WORKERS $MODE workers..."
for i in $(seq 1 $NUM_WORKERS); do
{
echo "[${MODE^} Worker-$i] Starting..."
MODE_CAPITALIZED=$(echo "$MODE" | sed 's/\(.\)/\U\1/')
echo "[$MODE_CAPITALIZED Worker-$i] Starting..."
# Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE ))
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 ))
# Each worker needs TP * DP GPUs
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * GPUS_PER_WORKER ))
END_GPU=$(( START_GPU + GPUS_PER_WORKER - 1 ))
# Build CUDA_VISIBLE_DEVICES string
if [ "$TENSOR_PARALLEL_SIZE" -eq 1 ]; then
# Build CUDA_VISIBLE_DEVICES string for all GPUs (TP * DP)
if [ "$GPUS_PER_WORKER" -eq 1 ]; then
GPU_DEVICES="$START_GPU"
else
GPU_DEVICES=""
......@@ -177,12 +192,17 @@ for i in $(seq 1 $NUM_WORKERS); do
if [ "$USE_MOCKERS" = true ]; then
# Run mocker engine (no GPU assignment needed)
exec python -m dynamo.mocker \
--model-path "$MODEL_PATH" \
--endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}"
MOCKER_ARGS=()
MOCKER_ARGS+=("--model-path" "$MODEL_PATH")
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
exec python -m dynamo.mocker "${MOCKER_ARGS[@]}"
elif [ "$USE_TRTLLM" = true ]; then
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run TensorRT-LLM engine with trtllm-llmapi-launch for proper initialization
TRTLLM_ARGS=()
TRTLLM_ARGS+=("--model-path" "$MODEL_PATH")
......@@ -195,11 +215,14 @@ for i in $(seq 1 $NUM_WORKERS); do
exec env CUDA_VISIBLE_DEVICES=$GPU_DEVICES trtllm-llmapi-launch python -m dynamo.trtllm \
"${TRTLLM_ARGS[@]}"
else
echo "[${MODE^} Worker-$i] Using GPUs: $GPU_DEVICES"
echo "[$MODE_CAPITALIZED Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker")
fi
......
......@@ -99,6 +99,7 @@ class StandaloneRouterHandler:
"eos_token_ids": request.get("eos_token_ids", []),
"annotations": request.get("annotations", []),
"disaggregated_params": request.get("disaggregated_params"),
"dp_rank": request.get("dp_rank"),
"extra_args": request.get("extra_args", {}),
}
......
......@@ -33,7 +33,7 @@ class BaseWorkerHandler(ABC):
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.kv_publisher = None
self.kv_publishers = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
@abstractmethod
......@@ -81,9 +81,16 @@ class BaseWorkerHandler(ABC):
"""Override in subclasses if cleanup is needed."""
pass
async def generate_tokens(self, prompt, sampling_params, request_id):
async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None
):
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
gen = self.engine_client.generate(
prompt,
sampling_params,
request_id,
data_parallel_rank=data_parallel_rank,
)
num_output_tokens_so_far = 0
try:
......@@ -211,10 +218,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return
logger.warning(f"Prefill error: {e}, falling back to local prefill")
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id):
try:
async for tok in self.generate_tokens(
prompt, sampling_params, request_id
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
):
yield tok
except EngineDeadError as e:
......@@ -241,9 +250,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True):
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
......
......@@ -5,7 +5,6 @@ import asyncio
import logging
import os
import signal
from typing import Optional
import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher
......@@ -107,33 +106,41 @@ def setup_kv_event_publisher(
component,
generate_endpoint,
vllm_config,
) -> Optional[ZmqKvEventPublisher]:
):
"""
Set up KV event publisher for prefix caching if enabled.
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Returns:
ZmqKvEventPublisher if prefix caching is enabled, None otherwise.
List of ZmqKvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.
"""
if not config.engine_args.enable_prefix_caching:
return None
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
# Get data_parallel_size to create publishers for all dp_ranks
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
kv_publishers = []
for dp_rank in range(data_parallel_size):
# Each dp_rank publishes to a different port
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=dp_rank,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
logger.info(f"Worker reading KV events from {zmq_endpoint}")
logger.info(
f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
)
return kv_publisher
return kv_publishers if kv_publishers else None
def setup_vllm_engine(config, stat_logger=None):
......@@ -200,12 +207,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params
)
# Set up KV event publisher for prefix caching if enabled
kv_publisher = setup_kv_event_publisher(
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
if kv_publishers:
handler.kv_publishers = kv_publishers
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
......@@ -285,12 +292,12 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client,
)
# Set up KV event publisher for prefix caching if enabled
kv_publisher = setup_kv_event_publisher(
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
if kv_publishers:
handler.kv_publishers = kv_publishers
if config.engine_args.disable_log_stats is False:
from prometheus_client import REGISTRY
......@@ -311,6 +318,12 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(
vllm_config.parallel_config, "data_parallel_size", 1
)
runtime_config.data_parallel_size = data_parallel_size
await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
......
......@@ -322,7 +322,11 @@ The `KvPushRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
- **`best_worker(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: **[DEPRECATED - use `best_worker()` instead]** Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
......
......@@ -207,6 +207,7 @@ fn kv_event_create_stored_from_parts(
parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
}),
event_id: kv_params.event_id,
dp_rank: 0,
}
}
......@@ -224,6 +225,7 @@ fn kv_event_create_removed_from_parts(
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: 0,
}
}
......
......@@ -101,7 +101,7 @@ impl WorkerMetricsPublisher {
#[derive(Clone)]
pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)]
pub worker_id: i64,
pub worker_id: WorkerId,
#[pyo3(get, set)]
pub kv_block_size: usize,
#[pyo3(get, set)]
......@@ -120,7 +120,7 @@ impl ZmqKvEventPublisherConfig {
zmq_topic = "".to_string()
))]
pub fn new(
worker_id: i64,
worker_id: WorkerId,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
......@@ -234,13 +234,20 @@ impl Drop for ZmqKvEventListener {
pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
kv_block_size: usize,
dp_rank: DpRank,
warning_count: Arc<AtomicU32>,
}
#[pymethods]
impl KvEventPublisher {
#[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> {
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
......@@ -256,6 +263,7 @@ impl KvEventPublisher {
Ok(Self {
inner: inner.into(),
kv_block_size,
dp_rank,
warning_count: Arc::new(AtomicU32::new(0)),
})
}
......@@ -286,6 +294,7 @@ impl KvEventPublisher {
&self.warning_count,
),
}),
dp_rank: self.dp_rank,
};
self.inner.publish(event).map_err(to_pyerr)
......@@ -299,6 +308,7 @@ impl KvEventPublisher {
let event = KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: self.dp_rank,
};
self.inner.publish(event).map_err(to_pyerr)
......@@ -314,8 +324,13 @@ pub(crate) struct OverlapScores {
#[pymethods]
impl OverlapScores {
#[getter]
fn scores(&self) -> HashMap<llm_rs::kv_router::indexer::WorkerId, u32> {
self.inner.scores.clone()
fn scores(&self) -> HashMap<(i64, u32), u32> {
// Return scores with full WorkerWithDpRank granularity as (worker_id, dp_rank) tuples
self.inner
.scores
.iter()
.map(|(worker, score)| ((worker.worker_id, worker.dp_rank), *score))
.collect()
}
#[getter]
......@@ -361,7 +376,7 @@ impl RadixTree {
fn apply_event(
&mut self,
_py: Python,
worker_id: i64,
worker_id: WorkerId,
kv_cache_event_bytes: &[u8],
) -> PyResult<()> {
let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent =
......@@ -377,12 +392,12 @@ impl RadixTree {
Ok(())
}
fn remove_worker(&mut self, _py: Python, worker_id: i64) -> PyResult<()> {
fn remove_worker(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.remove_worker(worker_id);
Ok(())
}
fn clear_all_blocks(&mut self, _py: Python, worker_id: i64) -> PyResult<()> {
fn clear_all_blocks(&mut self, _py: Python, worker_id: WorkerId) -> PyResult<()> {
self.inner.clear_all_blocks(worker_id);
Ok(())
}
......@@ -517,16 +532,19 @@ impl ApproxKvIndexer {
})
}
#[pyo3(signature = (tokens, worker_id, dp_rank=0))]
fn process_routing_decision_for_request<'p>(
&self,
py: Python<'p>,
tokens: Vec<u32>,
worker_id: i64,
worker_id: WorkerId,
dp_rank: DpRank,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let worker = llm_rs::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
indexer
.process_routing_decision_for_request(tokens.as_slice(), worker_id)
.process_routing_decision_for_request(tokens.as_slice(), worker)
.await
.map_err(to_pyerr)?;
Ok(())
......@@ -538,7 +556,7 @@ impl ApproxKvIndexer {
#[derive(Clone)]
pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)]
pub worker_id: i64,
pub worker_id: WorkerId,
#[pyo3(get, set)]
pub request_active_slots: u64,
#[pyo3(get, set)]
......@@ -784,7 +802,7 @@ impl WorkerStats {
request_active_slots: u64,
request_total_slots: u64,
num_requests_waiting: u64,
data_parallel_rank: Option<u32>,
data_parallel_rank: Option<DpRank>,
) -> Self {
Self(RsWorkerStats {
data_parallel_rank,
......@@ -961,7 +979,7 @@ impl KvPushRouter {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, extra_args=None))]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None, dp_rank=None, extra_args=None))]
fn generate<'p>(
&self,
py: Python<'p>,
......@@ -971,7 +989,8 @@ impl KvPushRouter {
sampling_options: Option<PyObject>,
output_options: Option<PyObject>,
router_config_override: Option<PyObject>,
worker_id: Option<i64>,
worker_id: Option<WorkerId>,
dp_rank: Option<DpRank>,
extra_args: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults
......@@ -1027,6 +1046,7 @@ impl KvPushRouter {
.sampling_options(sampling_options)
.output_options(output_options)
.router_config_override(router_config_override)
.dp_rank(dp_rank)
.extra_args(extra_args);
// Set backend_instance_id if worker_id is provided
......@@ -1053,6 +1073,43 @@ impl KvPushRouter {
Self::process_request_to_stream(py, self.inner.clone(), request)
}
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
router_config_override: Option<PyObject>,
request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
depythonize(obj.bind(py)).map_err(to_pyerr)?;
Ok::<_, PyErr>(Some(override_config))
})?
} else {
None
};
let chooser = self.inner.chooser.clone();
let update_states = request_id.is_some();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (best_worker, overlap_blocks) = chooser
.find_best_match(
request_id.as_deref(),
&token_ids,
router_config_override.as_ref(),
update_states,
)
.await
.map_err(to_pyerr)?;
Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
})
}
/// Deprecated: Use `best_worker()` instead which returns (worker_id, dp_rank, overlap_blocks)
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
fn best_worker_id<'p>(
&self,
......@@ -1061,6 +1118,16 @@ impl KvPushRouter {
router_config_override: Option<PyObject>,
request_id: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
// Issue deprecation warning
let warnings = py.import("warnings")?;
warnings.call_method1(
"warn",
(
"best_worker_id() is deprecated. Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks)",
py.get_type::<pyo3::exceptions::PyDeprecationWarning>(),
),
)?;
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
......@@ -1075,7 +1142,7 @@ impl KvPushRouter {
let update_states = request_id.is_some();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = chooser
let (best_worker, overlap_blocks) = chooser
.find_best_match(
request_id.as_deref(),
&token_ids,
......@@ -1085,8 +1152,8 @@ impl KvPushRouter {
.await
.map_err(to_pyerr)?;
// Return a tuple of (worker_id, overlap_blocks)
Ok((worker_id, overlap_blocks))
// Return only worker_id and overlap_blocks for backward compatibility
Ok((best_worker.worker_id, overlap_blocks))
})
}
......@@ -1130,6 +1197,7 @@ impl KvPushRouter {
.await
.map_err(to_pyerr)?;
// Return loads without aggregation - each (worker_id, dp_rank) pair is a separate entry
// Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
Python::with_gil(|py| {
pythonize(py, &loads)
......
......@@ -44,6 +44,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser = reasoning_parser;
}
#[setter]
fn set_data_parallel_size(&mut self, data_parallel_size: u32) {
self.inner.data_parallel_size = data_parallel_size;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
......
......@@ -778,16 +778,21 @@ class KvEventPublisher:
...
def __init__(
self, component: Component, worker_id: int, kv_block_size: int
self, component: Component, worker_id: int, kv_block_size: int, dp_rank: int = 0
) -> None:
"""
Create a `KvEventPublisher` object
Args:
component: The component to publish events for
worker_id: The worker ID
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
"""
def publish_stored(
self,
event_id,
int,
event_id: int,
token_ids: List[int],
num_block_tokens: List[int],
block_hashes: List[int],
......@@ -796,12 +801,24 @@ class KvEventPublisher:
) -> None:
"""
Publish a KV stored event.
Args:
event_id: The event ID
token_ids: List of token IDs
num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
"""
...
def publish_removed(self, event_id, int, block_hashes: List[int]) -> None:
def publish_removed(self, event_id: int, block_hashes: List[int]) -> None:
"""
Publish a KV removed event.
Args:
event_id: The event ID
block_hashes: List of block hashes to remove (signed 64-bit integers)
"""
...
......@@ -1199,6 +1216,7 @@ class KvPushRouter:
output_options: Optional[JsonLike] = None,
router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None,
dp_rank: Optional[int] = None,
) -> AsyncIterator[JsonLike]:
"""
Generate text using the KV-aware router.
......@@ -1213,6 +1231,10 @@ class KvPushRouter:
worker_id: Optional worker ID to route to directly. If set, the request
will be sent to this specific worker and router states will be
updated accordingly.
dp_rank: Optional data parallel rank to route to. If set along with worker_id,
the request will be routed to the specific (worker_id, dp_rank) pair.
If only dp_rank is set, the router will select the best worker but
force routing to the specified dp_rank.
Returns:
An async iterator yielding generation responses
......@@ -1220,10 +1242,36 @@ class KvPushRouter:
Note:
- If worker_id is set, the request bypasses KV matching and routes directly
to the specified worker while still updating router states.
- dp_rank allows targeting a specific data parallel replica when workers have
multiple replicas (data_parallel_size > 1).
- This is different from query_instance_id which doesn't route the request.
"""
...
async def best_worker(
self,
token_ids: List[int],
router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None,
) -> Tuple[int, int, int]:
"""
Find the best matching worker for the given tokens.
Args:
token_ids: List of token IDs to find matches for
router_config_override: Optional router configuration override
request_id: Optional request ID. If provided, router states will be updated
to track this request (active blocks, lifecycle events). If not
provided, this is a query-only operation that doesn't affect state.
Returns:
A tuple of (worker_id, dp_rank, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- dp_rank: The data parallel rank of the selected worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def best_worker_id(
self,
token_ids: List[int],
......@@ -1231,6 +1279,8 @@ class KvPushRouter:
request_id: Optional[str] = None,
) -> Tuple[int, int]:
"""
[DEPRECATED] Use best_worker() instead which returns (worker_id, dp_rank, overlap_blocks).
Find the best matching worker for the given tokens.
Args:
......@@ -1244,6 +1294,9 @@ class KvPushRouter:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
.. deprecated::
Use :meth:`best_worker` instead which also returns dp_rank.
"""
...
......@@ -1260,8 +1313,13 @@ class KvPushRouter:
Returns:
A list of dictionaries, each containing:
- worker_id: The worker ID
- dp_rank: The data parallel rank
- potential_prefill_tokens: Number of tokens that would need prefill
- potential_decode_blocks: Number of blocks currently in decode phase
Note:
Each (worker_id, dp_rank) pair is returned as a separate entry.
If you need aggregated loads per worker_id, sum the values manually.
"""
...
......@@ -1287,7 +1345,7 @@ class KvPushRouter:
Note:
This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing.
`best_worker()` with `request_id` for custom routing.
"""
...
......@@ -1304,7 +1362,7 @@ class KvPushRouter:
Note:
This is typically called automatically by the router when using the
`generate()` method. Only call this manually if you're using
`best_worker_id()` with `request_id` for custom routing.
`best_worker()` with `request_id` for custom routing.
"""
...
......
......@@ -85,17 +85,21 @@ async def test_radix_tree_binding(distributed_runtime):
overlap_scores = radix_tree.find_matches([0])
# Verify the results
# Note: scores is now Dict[(worker_id, dp_rank), score]
assert overlap_scores.scores is not None
assert (
len(overlap_scores.scores) == 1
), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}"
assert worker_id in overlap_scores.scores, f"Worker {worker_id} not found in scores"
worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert (
overlap_scores.scores[worker_id] == 1
), f"Expected score 1 for worker {worker_id}, got {overlap_scores.scores[worker_id]}"
worker_key in overlap_scores.scores
), f"Worker {worker_key} not found in scores"
assert (
overlap_scores.scores[worker_key] == 1
), f"Expected score 1 for worker {worker_key}, got {overlap_scores.scores[worker_key]}"
print(
f"✓ RadixTree test passed: worker {worker_id} has score {overlap_scores.scores[worker_id]}"
f"✓ RadixTree test passed: worker {worker_key} has score {overlap_scores.scores[worker_key]}"
)
......@@ -130,24 +134,25 @@ async def test_event_handler(distributed_runtime):
event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously
# Retry loop for CI environments where processing may take longer
worker_key = (worker_id, 0) # (worker_id, dp_rank)
for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id)
if (
scores.scores
and worker_id in scores.scores
and scores.scores[worker_id] == 1
and worker_key in scores.scores
and scores.scores[worker_key] == 1
):
break
if retry == 9: # Last iteration
# Provide detailed error message for debugging
assert scores.scores, f"No scores found after {(retry+1)*0.5}s"
assert (
worker_id in scores.scores
), f"Worker {worker_id} not in scores after {(retry+1)*0.5}s"
worker_key in scores.scores
), f"Worker {worker_key} not in scores after {(retry+1)*0.5}s"
assert (
scores.scores[worker_id] == 1
), f"Expected score 1, got {scores.scores.get(worker_id)} after {(retry+1)*0.5}s"
scores.scores[worker_key] == 1
), f"Expected score 1, got {scores.scores.get(worker_key)} after {(retry+1)*0.5}s"
# remove event
event_publisher.remove_event()
......@@ -185,8 +190,9 @@ async def test_approx_kv_indexer(distributed_runtime):
scores = await indexer.find_matches_for_request(tokens)
assert scores.scores
assert worker_id in scores.scores
assert scores.scores[worker_id] == 2
worker_key = (worker_id, 0) # (worker_id, dp_rank)
assert worker_key in scores.scores
assert scores.scores[worker_key] == 2
class EventPublisher:
......@@ -281,7 +287,7 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics["request_active_slots"],
expected_metrics["request_total_slots"],
expected_metrics["num_requests_waiting"],
None,
0, # data_parallel_rank (0 = DP not enabled)
)
kv_stats = KvStats(
......
......@@ -38,7 +38,9 @@ use crate::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scoring::ProcessedEndpoints,
subscriber::start_kv_router_background,
......@@ -74,7 +76,7 @@ pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
pub trait WorkerSelector {
fn select_worker(
&self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
......@@ -316,7 +318,7 @@ impl KvRouter {
}
/// Give these tokens, find the worker with the best match in it's KV cache.
/// Returned overlap amount is in number of blocks.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
pub async fn find_best_match(
&self,
......@@ -324,7 +326,7 @@ impl KvRouter {
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> anyhow::Result<(i64, u32)> {
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
......@@ -350,7 +352,7 @@ impl KvRouter {
(false, false) => (None, None),
};
let best_worker_id = self
let best_worker = self
.scheduler
.schedule(
context_id.map(|s| s.to_string()),
......@@ -364,17 +366,17 @@ impl KvRouter {
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision(best_worker_id, block_hashes, maybe_seq_hashes_1.unwrap())
.process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
.await
.unwrap();
};
let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
.get(&best_worker)
.copied()
.unwrap_or(0);
Ok((best_worker_id, overlap_amount))
Ok((best_worker, overlap_amount))
}
pub async fn add_request(
......@@ -382,7 +384,7 @@ impl KvRouter {
request_id: String,
tokens: &[u32],
overlap_blocks: u32,
worker_id: i64,
worker: WorkerWithDpRank,
) {
let isl_tokens = tokens.len();
......@@ -397,7 +399,7 @@ impl KvRouter {
maybe_seq_hashes,
isl_tokens,
overlap_blocks,
worker_id,
worker,
)
.await;
}
......@@ -450,12 +452,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
// Handle different request types
let response = match request {
RouterRequest::New { tokens } => {
let (worker_id, overlap_blocks) = self
let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true)
.await?;
RouterResponse::New {
worker_id,
worker_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_blocks,
}
}
......@@ -523,24 +526,45 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it and manually add the request to track it
if !query_instance_id {
self.chooser
.add_request(context_id.clone(), &request.token_ids, 0, id)
.await;
let (instance_id, dp_rank, overlap_amount) = if let Some(id) =
request.backend_instance_id
{
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
}
(id, 0)
// Compute actual overlap blocks by querying the indexer
let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
self.chooser
.add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
(id, dp_rank, overlap_blocks)
} else {
// Otherwise, find the best match
self.chooser
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
)
.await?
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// if request has the annotation "query_instance_id",
......@@ -564,6 +588,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
backend_input.dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
......
......@@ -27,11 +27,11 @@ use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
WorkerId, compute_block_hash_for_seq,
compute_block_hash_for_seq,
};
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
KvCacheStoredBlockData, LocalBlockHash, WorkerId, WorkerWithDpRank,
};
#[derive(Debug)]
......@@ -44,8 +44,8 @@ struct MatchRequest {
#[derive(Debug)]
struct RouterResult {
/// The id of the selected worker.
worker_id: WorkerId,
/// The worker (with dp_rank) that was selected.
worker: WorkerWithDpRank,
/// The local hashes of the tokens sent to the worker.
local_hashes: Vec<LocalBlockHash>,
......@@ -58,8 +58,8 @@ struct RouterResult {
struct TimerEntry {
/// The key of the timer.
key: ExternalSequenceBlockHash,
/// The worker id that stored this block.
worker: WorkerId,
/// The worker (with dp_rank) that stored this block.
worker: WorkerWithDpRank,
}
/// A data structure to manage a collection of timers, addressable by a key.
......@@ -237,10 +237,11 @@ impl ApproxKvIndexer {
event_id += 1;
let event = RouterEvent::new(
result.worker_id,
result.worker.worker_id,
KvCacheEvent {
event_id,
data: stored_event,
dp_rank: result.worker.dp_rank,
}
);
......@@ -248,7 +249,7 @@ impl ApproxKvIndexer {
timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry {
key: ExternalSequenceBlockHash(*h),
worker: result.worker_id,
worker: result.worker,
}).collect());
}
......@@ -269,12 +270,13 @@ impl ApproxKvIndexer {
event_id += 1;
let event = RouterEvent::new(
e.worker,
e.worker.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key],
}),
dp_rank: e.worker.dp_rank,
}
);
......@@ -307,13 +309,13 @@ impl ApproxKvIndexer {
/// Core function to process a routing decision with pre-computed hashes
pub async fn process_routing_decision(
&self,
worker_id: WorkerId,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
self.route_tx
.send(RouterResult {
worker_id,
worker,
local_hashes,
sequence_hashes,
})
......@@ -327,7 +329,7 @@ impl ApproxKvIndexer {
pub async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker_id: WorkerId,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
......@@ -338,7 +340,7 @@ impl ApproxKvIndexer {
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.process_routing_decision(worker_id, local_hashes, sequence_hashes)
self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
}
......@@ -526,14 +528,20 @@ mod tests {
// 2. Inform indexer about routing decision
indexer
.process_routing_decision_for_request(&tokens, worker_id)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await
.unwrap();
// Poll until we observe the match being registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_id).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_id))
.copied()
== Some(1)
})
.await;
......@@ -554,14 +562,18 @@ mod tests {
let worker_id: WorkerId = 7;
indexer
.process_routing_decision_for_request(&tokens, worker_id)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_id),
)
.await
.unwrap();
// Wait until the worker is registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.contains_key(&worker_id)
s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
})
.await;
......@@ -571,7 +583,8 @@ mod tests {
// Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_id)
!s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
})
.await;
}
......@@ -590,19 +603,31 @@ mod tests {
// Register on both workers
indexer
.process_routing_decision_for_request(&tokens, worker_0)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await
.unwrap();
indexer
.process_routing_decision_for_request(&tokens, worker_1)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await
.unwrap();
// Ensure both workers are registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1)
&& s.scores.get(&worker_1).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
})
.await;
......@@ -612,7 +637,12 @@ mod tests {
// Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
!s.scores.contains_key(&worker_0) && s.scores.get(&worker_1).copied() == Some(1)
!s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
})
.await;
}
......@@ -631,14 +661,20 @@ mod tests {
// Register Sequence A on worker A
indexer
.process_routing_decision_for_request(&seq_a, worker_a)
.process_routing_decision_for_request(
&seq_a,
WorkerWithDpRank::from_worker_id(worker_a),
)
.await
.unwrap();
// Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
s.scores.get(&worker_a).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a))
.copied()
== Some(1)
})
.await;
......@@ -649,7 +685,12 @@ mod tests {
let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();
// Expect worker A to have an overlap score of 1 (shared first block)
assert_eq!(overlap.scores.get(&worker_a), Some(&1));
assert_eq!(
overlap
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a)),
Some(&1)
);
}
/// When the same block resides on multiple workers, all should appear in the overlap scores.
......@@ -666,25 +707,47 @@ mod tests {
// Register the same sequence on two different workers
indexer
.process_routing_decision_for_request(&tokens, worker_0)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_0),
)
.await
.unwrap();
indexer
.process_routing_decision_for_request(&tokens, worker_1)
.process_routing_decision_for_request(
&tokens,
WorkerWithDpRank::from_worker_id(worker_1),
)
.await
.unwrap();
// Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
s.scores.get(&worker_0).copied() == Some(1)
&& s.scores.get(&worker_1).copied() == Some(1)
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
== Some(1)
&& s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.copied()
== Some(1)
})
.await;
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert_eq!(scores.scores.get(&worker_0), Some(&1));
assert_eq!(scores.scores.get(&worker_1), Some(&1));
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0)),
Some(&1)
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1)),
Some(&1)
);
}
}
......@@ -80,9 +80,6 @@ pub enum KvCacheEventError {
BlockNotFound,
}
/// Identifier of a LLM worker which emits events to the router.
pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
......@@ -200,9 +197,9 @@ impl RouterEvent {
struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// A map of worker IDs to their external sequence block hash for this block.
/// A map of workers (with dp_rank) to their external sequence block hash for this block.
/// The external hash is preserved to speed up snapshotting.
workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
workers: HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed
recent_uses: VecDeque<Instant>,
}
......@@ -235,7 +232,7 @@ pub struct RadixTree {
/// Transitioning to a radix tree only would require a change in the messaging structure
/// as the entire prefix would need to be sent. Alternatively, we could use block_depth
/// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering frequence of block accesses
expiration_duration: Option<Duration>,
}
......@@ -332,11 +329,15 @@ impl RadixTree {
///
/// * `event` - The `RouterEvent` to apply.
pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
let (worker_id, event) = (event.worker_id, event.event);
let (id, op) = (event.event_id, event.data);
let (worker_id, kv_event) = (event.worker_id, event.event);
let (id, op) = (kv_event.event_id, kv_event.data);
// Construct WorkerWithDpRank from worker_id and dp_rank from the event
let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
tracing::trace!(id, "RadixTree::apply_event: Store operation: {:?}", op);
let worker_lookup = self.lookup.entry(worker_id).or_default();
let worker_lookup = self.lookup.entry(worker).or_default();
match op {
KvCacheEventData::Stored(op) => {
......@@ -352,7 +353,8 @@ impl RadixTree {
Some(current) => current.clone(),
None => {
tracing::warn!(
worker_id = worker_id.to_string(),
worker_id = worker.worker_id.to_string(),
dp_rank = ?worker.dp_rank,
id,
parent_hash = ?op.parent_hash,
"Failed to find parent block; skipping store operation"
......@@ -381,11 +383,11 @@ impl RadixTree {
}
};
// add our worker_id to the block with its external hash
// add our worker to the block with its external hash
block
.borrow_mut()
.workers
.insert(worker_id, block_id.block_hash);
.insert(worker, block_id.block_hash);
// add the block to the worker_id lookup table
worker_lookup.insert(block_id.block_hash, block.clone());
......@@ -419,7 +421,7 @@ impl RadixTree {
};
let mut guard = entry.borrow_mut();
guard.workers.remove(&worker_id);
guard.workers.remove(&worker);
if guard.workers.is_empty() {
// if no workers are using this block, that is true for all children
guard.children.clear();
......@@ -430,48 +432,57 @@ impl RadixTree {
Ok(())
}
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker_id);
self.clear_all_blocks(worker.worker_id);
Ok(())
}
}
}
pub fn remove_worker(&mut self, worker: WorkerId) {
if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
blocks.iter().for_each(|(_, block)| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
}
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = self
.lookup
.keys()
.filter(|w| w.worker_id == worker_id)
.copied()
.collect();
pub fn clear_all_blocks(&mut self, worker: WorkerId) {
// Check if the worker has any blocks to clear
if let Some(blocks) = self.lookup.get(&worker) {
let blocks_to_clear: Vec<_> = blocks.values().collect();
for worker in workers {
if let Some((worker_key, blocks)) = self.lookup.remove_entry(&worker) {
blocks.iter().for_each(|(_, block)| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
// Remove the worker from each block's workers map
blocks_to_clear.iter().for_each(|block| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
if keep_worker {
// Re-insert worker with empty blocks map to keep it tracked
self.lookup.insert(worker_key, HashMap::new());
}
});
// Clear the worker's blocks
if let Some(worker_lookup) = self.lookup.get_mut(&worker) {
worker_lookup.clear();
}
}
}
pub fn remove_worker(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
/// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub fn get_workers(&self) -> Vec<WorkerId> {
self.lookup.keys().copied().collect()
let mut worker_ids: Vec<WorkerId> = self.lookup.keys().map(|w| w.worker_id).collect();
worker_ids.sort_unstable();
worker_ids.dedup();
worker_ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
......@@ -487,10 +498,10 @@ impl RadixTree {
let mut event_id = 0u64;
// BFS queue: (current_block, parent_hashes_per_worker, tokens_hash)
// parent_hashes_per_worker maps WorkerId -> ExternalSequenceBlockHash
// parent_hashes_per_worker maps WorkerWithDpRank -> ExternalSequenceBlockHash
let mut queue: VecDeque<(
SharedRadixBlock,
HashMap<WorkerId, ExternalSequenceBlockHash>,
HashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
LocalBlockHash,
)> = VecDeque::new();
......@@ -514,7 +525,7 @@ impl RadixTree {
// Create a store event for this worker
let event = RouterEvent {
worker_id: *worker_id,
worker_id: worker_id.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
......@@ -524,6 +535,7 @@ impl RadixTree {
tokens_hash,
}],
}),
dp_rank: worker_id.dp_rank,
},
};
events.push(event);
......@@ -639,11 +651,11 @@ impl KvIndexerMetrics {
}
}
/// Scores representing the overlap of workers.
/// Scores representing the overlap of workers (with their dp_rank).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
// map of worker_id to score
pub scores: HashMap<WorkerId, u32>,
// map of worker (with dp_rank) to score
pub scores: HashMap<WorkerWithDpRank, u32>,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>,
}
......@@ -671,10 +683,10 @@ impl OverlapScores {
///
/// ### Arguments
///
/// * `workers` - An iterator over `WorkerId` references.
/// * `workers` - An iterator over `WorkerWithDpRank` references.
pub fn update_scores<'a, I>(&mut self, workers: I)
where
I: IntoIterator<Item = &'a WorkerId>,
I: IntoIterator<Item = &'a WorkerWithDpRank>,
{
for worker in workers {
let score = self.scores.entry(*worker).or_insert(0);
......@@ -1344,6 +1356,7 @@ mod tests {
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
}
......@@ -1359,6 +1372,7 @@ mod tests {
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
}
}
......@@ -1379,10 +1393,22 @@ mod tests {
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(trie.lookup.len(), 1);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!(
......@@ -1415,12 +1441,36 @@ mod tests {
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
assert_eq!(scores.scores.get(&worker_2).unwrap(), &1);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&1
);
assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 3);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
3
);
assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!(
......@@ -1449,8 +1499,20 @@ mod tests {
trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
.unwrap();
assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 2);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
2
);
assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!(
......@@ -1480,8 +1542,20 @@ mod tests {
.unwrap();
assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 1);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
1
);
assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!(
......@@ -1519,12 +1593,36 @@ mod tests {
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(scores.scores.get(&worker_1).unwrap(), &3);
assert_eq!(scores.scores.get(&worker_2).unwrap(), &2);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&2
);
assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 4);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.len(),
3
);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.len(),
4
);
assert_eq!(trie.root.borrow().workers.len(), 0);
assert_eq!(trie.root.borrow().children.len(), 1);
assert_eq!(
......@@ -1551,7 +1649,7 @@ mod tests {
);
assert_eq!(
trie.lookup
.get(&worker_1)
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap()
......@@ -1562,7 +1660,7 @@ mod tests {
);
assert_eq!(
trie.lookup
.get(&worker_2)
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap()
......@@ -1620,12 +1718,16 @@ mod tests {
.unwrap();
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);
assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 1
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
trie.remove_worker(worker_0);
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 1 && result[&worker_1] == 1);
assert!(result.len() == 1 && result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1);
}
#[test]
......@@ -1643,7 +1745,11 @@ mod tests {
// Test clearing an empty worker
trie.clear_all_blocks(worker_0);
assert!(!trie.lookup.contains_key(&worker_0));
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
// Test clearing a worker with shared blocks
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
......@@ -1652,17 +1758,29 @@ mod tests {
.unwrap();
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);
assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 1
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
trie.clear_all_blocks(worker_0);
assert!(trie.lookup.contains_key(&worker_0));
assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
assert!(
trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.is_empty()
);
let result = trie
.find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 2);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 2);
let result = trie
.find_matches(
vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)],
......@@ -1670,7 +1788,7 @@ mod tests {
)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 1);
// Test re-adding blocks after clearing worker
trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None))
......@@ -1679,19 +1797,32 @@ mod tests {
.find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_0], 2);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_0)], 2);
// Test multiple clears
trie.clear_all_blocks(worker_0);
trie.clear_all_blocks(worker_0);
assert!(trie.lookup.contains_key(&worker_0));
assert!(
trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
// Test clearing all workers
trie.clear_all_blocks(worker_0);
trie.clear_all_blocks(worker_1);
assert!(!trie.lookup.is_empty());
assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
assert!(trie.lookup.get(&worker_1).unwrap().is_empty());
assert!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.is_empty()
);
assert!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.is_empty()
);
// Test clearing a worker that has been removed
trie.apply_event(create_store_event(worker_0, 0, vec![6], None))
......@@ -1700,20 +1831,35 @@ mod tests {
.unwrap();
trie.remove_worker(worker_0);
trie.clear_all_blocks(worker_0);
assert!(!trie.lookup.contains_key(&worker_0));
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 1);
// Test clearing a worker that doesn't exist
let worker_fake = 2;
assert!(!trie.lookup.contains_key(&worker_fake));
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_fake))
);
trie.clear_all_blocks(worker_fake);
assert!(!trie.lookup.contains_key(&worker_fake));
assert!(trie.lookup.contains_key(&worker_1));
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_fake))
);
assert!(
trie.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 1);
}
#[test]
......@@ -1736,12 +1882,20 @@ mod tests {
)
.scores;
assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
let result = trie
.find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true)
.scores;
assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
}
#[rstest]
......@@ -1968,6 +2122,7 @@ mod tests {
tokens_hash: LocalBlockHash(13226331709069118873),
}],
}),
dp_rank: 0,
};
let router_event = RouterEvent::new(worker_id, kv_cache_event);
......@@ -2239,49 +2394,94 @@ mod tests {
.unwrap();
// Verify worker_0 has 3 blocks in lookup
assert_eq!(trie.lookup.get(&worker_0).unwrap().len(), 3);
assert_eq!(
trie.lookup
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.len(),
3
);
// Verify that blocks have the correct workers
let block_1 = trie
.lookup
.get(&worker_0)
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
assert!(block_1.borrow().workers.contains_key(&worker_0));
assert!(block_1.borrow().workers.contains_key(&worker_1));
assert!(block_1.borrow().workers.contains_key(&worker_2));
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
);
// Remove worker_0
trie.remove_worker(worker_0);
// Verify worker_0 is completely removed from lookup table
assert!(!trie.lookup.contains_key(&worker_0));
assert!(
!trie
.lookup
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert_eq!(trie.lookup.len(), 2);
// Verify that worker_0's hash is removed from the workers set
let block_1 = trie
.lookup
.get(&worker_1)
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
assert!(!block_1.borrow().workers.contains_key(&worker_0));
assert!(block_1.borrow().workers.contains_key(&worker_1));
assert!(block_1.borrow().workers.contains_key(&worker_2));
assert!(
!block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
assert!(
block_1
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
);
// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let block_2 = trie
.lookup
.get(&worker_1)
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap();
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
assert!(block_2.borrow().workers.contains_key(&worker_1));
assert!(
block_2
.borrow()
.workers
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1))
);
// Verify match results no longer include worker_0
let result = trie
......@@ -2291,8 +2491,8 @@ mod tests {
)
.scores;
assert_eq!(result.len(), 2);
assert!(!result.contains_key(&worker_0));
assert!(result.contains_key(&worker_1));
assert!(result.contains_key(&worker_2));
assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)));
}
}
......@@ -5,6 +5,34 @@ use crate::tokens::{SequenceHash, Token};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A worker identifier.
pub type WorkerId = i64;
/// A data parallel rank identifier.
pub type DpRank = u32;
/// A worker identifier combined with its data parallel rank.
/// Used for routing decisions in data parallel setups.
/// dp_rank = 0 indicates either DP not enabled or the first rank.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerWithDpRank {
pub worker_id: WorkerId,
pub dp_rank: DpRank,
}
impl WorkerWithDpRank {
pub fn new(worker_id: WorkerId, dp_rank: DpRank) -> Self {
Self { worker_id, dp_rank }
}
pub fn from_worker_id(worker_id: WorkerId) -> Self {
Self {
worker_id,
dp_rank: 0,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterRequest {
......@@ -26,15 +54,24 @@ impl Default for RouterRequest {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum RouterResponse {
New { worker_id: i64, overlap_blocks: u32 },
PrefillMarked { success: bool },
FreeMarked { success: bool },
New {
worker_id: WorkerId,
#[serde(default)]
dp_rank: DpRank,
overlap_blocks: u32,
},
PrefillMarked {
success: bool,
},
FreeMarked {
success: bool,
},
}
#[derive(Debug)]
pub struct WorkerSelectionResult {
/// The worker id of the selected worker
pub worker_id: i64,
/// The full worker information including dp_rank
pub worker: WorkerWithDpRank,
/// The total number of blocks required to prefill the request
pub required_blocks: u64,
......@@ -54,7 +91,7 @@ pub struct ForwardPassMetrics {
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WorkerStats {
// https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
pub data_parallel_rank: Option<u32>,
pub data_parallel_rank: Option<DpRank>,
pub request_active_slots: u64,
pub request_total_slots: u64,
pub num_requests_waiting: u64,
......@@ -136,7 +173,7 @@ impl From<i64> for ExternalSequenceBlockHash {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillEvent {
pub request_id: String,
pub worker_id: i64,
pub worker_id: WorkerId,
pub data: PrefillEventData,
pub router_id: Uuid,
}
......@@ -155,7 +192,7 @@ pub enum PrefillEventData {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ActiveSequenceEvent {
pub request_id: String,
pub worker_id: i64,
pub worker: WorkerWithDpRank,
pub data: ActiveSequenceEventData,
pub router_id: Uuid,
}
......@@ -199,6 +236,9 @@ pub struct KvCacheEvent {
pub event_id: u64,
/// The data associated with the event.
pub data: KvCacheEventData,
/// The data parallel rank of the worker emitting this event (0 if DP not enabled).
#[serde(default)]
pub dp_rank: DpRank,
}
/// Represents the data associated with a cache event.
......@@ -313,6 +353,7 @@ mod tests {
let event = KvCacheEvent {
event_id: 1,
data: event_data,
dp_rank: 0,
};
let events = KvCacheEvents {
......
......@@ -326,13 +326,16 @@ pub async fn start_zmq_listener(
};
tracing::trace!(
"ZMQ listener on {} received batch with {} events (seq={})",
"ZMQ listener on {} received batch with {} events (seq={}, dp_rank={})",
zmq_endpoint,
batch.events.len(),
seq
seq,
batch.data_parallel_rank
);
let dp_rank = batch.data_parallel_rank;
for raw_event in batch.events.into_iter() {
let event = convert_event(raw_event, seq, kv_block_size, &warning_count);
let event = convert_event(raw_event, seq, kv_block_size, dp_rank, &warning_count);
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped";
......@@ -356,6 +359,7 @@ fn convert_event(
raw: RawKvEvent,
event_id: u64,
kv_block_size: u32,
dp_rank: u32,
warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent {
match raw {
......@@ -387,6 +391,7 @@ fn convert_event(
warning_count,
),
}),
dp_rank,
}
}
RawKvEvent::BlockRemoved { block_hashes, .. } => {
......@@ -400,11 +405,13 @@ fn convert_event(
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes,
}),
dp_rank,
}
}
RawKvEvent::AllBlocksCleared => KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
dp_rank,
},
}
}
......@@ -1014,7 +1021,7 @@ mod test_event_processing {
medium: None,
};
let out = convert_event(raw_evt, 42, kv_block_size, &Arc::new(AtomicU32::new(0)));
let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Stored(_)));
}
......@@ -1025,7 +1032,7 @@ mod test_event_processing {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None,
};
let out = convert_event(raw_evt, 7, kv_block_size, &Arc::new(AtomicU32::new(0)));
let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Removed(_)));
}
......@@ -1034,7 +1041,7 @@ mod test_event_processing {
fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared;
let out = convert_event(raw_evt, 1, kv_block_size, &Arc::new(AtomicU32::new(0)));
let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared));
}
}
......@@ -1115,6 +1122,7 @@ mod tests_startup_helpers {
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}),
dp_rank: 0,
};
let token = CancellationToken::new();
......
......@@ -12,7 +12,6 @@ mod tests {
use super::*;
use crate::kv_router::indexer::KvIndexer;
use crate::kv_router::indexer::KvIndexerMetrics;
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::*;
use std::time::Duration;
use tempfile::tempdir;
......@@ -50,6 +49,7 @@ mod tests {
KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
)
}
......@@ -65,6 +65,7 @@ mod tests {
.map(|i| ExternalSequenceBlockHash(*i * 100))
.collect(),
}),
dp_rank: 0,
},
)
}
......
......@@ -18,21 +18,24 @@ use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult;
use super::protocols::{DpRank, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::sequence::ActiveSequencesMultiWorker;
use crate::tokens::SequenceHash;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
pub worker_id: i64,
pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
pub isl_blocks: usize,
pub overlap_blocks: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
pub worker_id: i64,
pub worker_id: WorkerId,
pub dp_rank: DpRank,
pub potential_prefill_tokens: usize,
pub potential_decode_blocks: usize,
}
......@@ -51,7 +54,7 @@ pub enum KvSchedulerError {
#[derive(Debug)]
pub struct SchedulingResponse {
pub best_worker_id: i64,
pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32,
}
......@@ -60,8 +63,8 @@ pub struct SchedulingRequest {
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub decode_blocks: HashMap<i64, usize>,
pub prefill_tokens: HashMap<i64, usize>,
pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
// Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests)
......@@ -94,17 +97,18 @@ impl KvScheduler {
component: Component,
block_size: u32,
instances_rx: watch::Receiver<Vec<Instance>>,
runtime_configs_rx: watch::Receiver<HashMap<i64, ModelRuntimeConfig>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
let runtime_configs: HashMap<i64, ModelRuntimeConfig> = runtime_configs_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<i64, Option<ModelRuntimeConfig>>>> = {
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new();
for instance in &instances {
let worker_id = instance.instance_id;
......@@ -117,14 +121,10 @@ impl KvScheduler {
Arc::new(RwLock::new(initial_map))
};
let worker_ids: Vec<i64> = instances
.iter()
.map(|instance| instance.instance_id)
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size as usize,
worker_ids,
workers_with_configs.read().await.clone(), // this includes dp_size info
replica_sync,
router_uuid,
));
......@@ -162,24 +162,23 @@ impl KvScheduler {
let new_instances = instances_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone();
// Update workers when instances change
let worker_ids: Vec<i64> = new_instances
.iter()
.map(|instance| instance.instance_id)
.collect();
slots_monitor.update_workers(worker_ids);
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
workers_map.clear();
// Build the new workers_with_configs map
let mut new_workers_with_configs = HashMap::new();
for instance in &new_instances {
let worker_id = instance.instance_id;
let config = new_configs.get(&worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
workers_map.insert(worker_id, config);
new_workers_with_configs.insert(worker_id, config);
}
// Update workers when instances change
slots_monitor.update_workers(new_workers_with_configs.clone());
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
*workers_map = new_workers_with_configs;
tracing::trace!(
"Updated workers_with_configs with {} workers",
workers_map.len()
......@@ -229,7 +228,8 @@ impl KvScheduler {
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => {
let event = KVHitRateEvent {
worker_id: selection.worker_id,
worker_id: selection.worker.worker_id,
dp_rank: selection.worker.dp_rank,
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
};
......@@ -238,7 +238,7 @@ impl KvScheduler {
}
let response = SchedulingResponse {
best_worker_id: selection.worker_id,
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
};
request.respond(response);
......@@ -261,7 +261,7 @@ impl KvScheduler {
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
selection.worker,
)
.await
{
......@@ -302,7 +302,7 @@ impl KvScheduler {
overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> Result<i64, KvSchedulerError> {
) -> Result<WorkerWithDpRank, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
......@@ -324,8 +324,7 @@ impl KvScheduler {
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
let best_worker_id = response.best_worker_id;
Ok(best_worker_id)
Ok(response.best_worker)
}
pub async fn add_request(
......@@ -334,11 +333,11 @@ impl KvScheduler {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
worker_id: i64,
worker: WorkerWithDpRank,
) {
let _ = self
.slots
.add_request(request_id, token_sequence, isl, overlap, worker_id)
.add_request(request_id, token_sequence, isl, overlap, worker)
.await;
}
......@@ -363,21 +362,22 @@ impl KvScheduler {
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
.await;
// Get all unique worker IDs from both hashmaps
let mut worker_ids: HashSet<i64> = HashSet::new();
worker_ids.extend(decode_blocks.keys().copied());
worker_ids.extend(prefill_tokens.keys().copied());
// Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker
let mut loads = Vec::new();
for worker_id in worker_ids {
for worker in workers {
loads.push(PotentialLoad {
worker_id,
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens
.get(&worker_id)
.get(&worker)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0),
potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
});
}
......@@ -386,7 +386,7 @@ impl KvScheduler {
}
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) -> WorkerWithDpRank {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
......@@ -474,7 +474,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
workers: &HashMap<WorkerId, Option<ModelRuntimeConfig>>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
......@@ -494,38 +494,52 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut worker_logits = HashMap::new();
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker
for worker_id in workers.keys() {
let overlap = *overlaps.get(worker_id).unwrap_or(&0);
// this is the number of prefill tokens the worker would have if the request were scheduled there
let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks
.get(worker_id)
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64;
// Use override if provided, otherwise use default config
let overlap_weight = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight);
// Calculate logit (lower is better)
let logit = overlap_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit);
tracing::info!(
"Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
);
// Calculate logits for each worker with dp_rank
// Outer loop: iterate over all workers from runtime config
// Inner loop: iterate over all dp_ranks for each worker
for (worker_id, config) in workers.iter() {
// Get data_parallel_size from runtime config
// data_parallel_size defaults to 1 in ModelRuntimeConfig
let data_parallel_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); // Fallback if config is None
// Iterate over all dp_ranks for this worker
for dp_rank in 0..data_parallel_size {
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
// Get overlap for this worker (defaults to 0 if not in overlaps)
let overlap = *overlaps.get(&worker).unwrap_or(&0);
// this is the number of prefill tokens the worker would have if the request were scheduled there
let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks
.get(&worker)
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64;
// Use override if provided, otherwise use default config
let overlap_weight = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight);
// Calculate logit (lower is better)
let logit = overlap_weight * potential_prefill_block + decode_block;
max_logit = max_logit.max(logit);
worker_logits.insert(worker, logit);
tracing::info!(
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
);
}
}
// Use softmax sampling to select worker
......@@ -535,29 +549,32 @@ impl WorkerSelector for DefaultWorkerSelector {
.as_ref()
.and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature);
let best_worker_id = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker_id];
let best_worker = softmax_sample(&worker_logits, temperature);
let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0);
// this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers
.get(&best_worker_id)
.get(&best_worker.worker_id)
.and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
tracing::info!(
"Selected worker: {}, logit: {:.3}, cached blocks: {}{}",
best_worker_id,
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}",
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_overlap,
total_blocks_info
);
Ok(WorkerSelectionResult {
worker_id: best_worker_id,
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0),
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
})
}
}
......@@ -570,54 +587,61 @@ mod tests {
fn test_softmax_sample_single_key() {
// Test that with a single key, softmax_sample always returns that key
let mut logits = HashMap::new();
let worker_id = 42;
logits.insert(worker_id, 0.5); // The value doesn't matter
let worker = WorkerWithDpRank::from_worker_id(42);
logits.insert(worker, 0.5); // The value doesn't matter
// Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature);
assert_eq!(result, worker_id, "Should return the only available worker");
assert_eq!(result, worker, "Should return the only available worker");
}
// Test with different logit values
logits.clear();
logits.insert(worker_id, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.insert(worker, -100.0); // Very negative value
assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear();
logits.insert(worker_id, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.insert(worker, 100.0); // Very positive value
assert_eq!(softmax_sample(&logits, 1.0), worker);
logits.clear();
logits.insert(worker_id, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker_id);
logits.insert(worker, 0.0); // Zero value
assert_eq!(softmax_sample(&logits, 1.0), worker);
}
#[test]
fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns the key with smallest logit
let mut logits = HashMap::new();
logits.insert(1, 5.0);
logits.insert(2, 3.0); // This has the smallest logit
logits.insert(3, 7.0);
logits.insert(4, 3.5);
let worker1 = WorkerWithDpRank::from_worker_id(1);
let worker2 = WorkerWithDpRank::from_worker_id(2);
let worker3 = WorkerWithDpRank::from_worker_id(3);
let worker4 = WorkerWithDpRank::from_worker_id(4);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // This has the smallest logit
logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5);
// With temperature 0, should always return worker 2 (smallest logit)
for _ in 0..10 {
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result, 2,
result, worker2,
"Should return worker with smallest logit when temperature is 0"
);
}
// Test with negative values
logits.clear();
logits.insert(10, -1.0);
logits.insert(20, -5.0); // This has the smallest logit
logits.insert(30, 0.0);
let worker10 = WorkerWithDpRank::from_worker_id(10);
let worker20 = WorkerWithDpRank::from_worker_id(20);
let worker30 = WorkerWithDpRank::from_worker_id(30);
logits.insert(worker10, -1.0);
logits.insert(worker20, -5.0); // This has the smallest logit
logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(result, 20, "Should handle negative logits correctly");
assert_eq!(result, worker20, "Should handle negative logits correctly");
}
}
......@@ -23,7 +23,6 @@
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::indexer::WorkerId;
use crate::tokens::SequenceHash;
use anyhow::Result;
use dashmap::DashMap;
......@@ -39,8 +38,9 @@ use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;
use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank};
use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::CancellationToken;
/// Duration after which stale requests are forcibly expired (5 minutes)
......@@ -280,9 +280,9 @@ enum UpdateSequences {
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
handles: Arc<DashMap<WorkerId, std::thread::JoinHandle<()>>>,
senders: Arc<DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
handles: Arc<DashMap<WorkerWithDpRank, std::thread::JoinHandle<()>>>,
block_size: usize,
component: Component,
router_id: Uuid,
......@@ -293,7 +293,7 @@ impl ActiveSequencesMultiWorker {
pub fn new(
component: Component,
block_size: usize,
worker_ids: Vec<WorkerId>,
workers_with_configs: HashMap<i64, Option<ModelRuntimeConfig>>,
replica_sync: bool,
router_uuid: String,
) -> Self {
......@@ -311,12 +311,18 @@ impl ActiveSequencesMultiWorker {
Uuid::new_v4()
});
for worker_id in worker_ids {
// Create a child cancellation token from the component's runtime
let cancel_token = component.drt().runtime().child_token();
let (sender, handle) = Self::start_worker(block_size, cancel_token);
senders.insert(worker_id, sender);
handles.insert(worker_id, handle);
// Expand workers by their dp_rank
for (worker_id, config) in workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1);
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Create a child cancellation token from the component's runtime
let cancel_token = component.drt().runtime().child_token();
let (sender, handle) = Self::start_worker(block_size, cancel_token);
senders.insert(worker, sender);
handles.insert(worker, handle);
}
}
let multi_worker = Self {
......@@ -458,8 +464,10 @@ impl ActiveSequencesMultiWorker {
/// Background task to subscribe to active sequence events and update all workers
async fn subscribe_to_events(
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
senders: Arc<
DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>,
>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
component: Component,
router_id: Uuid,
cancel_token: CancellationToken,
......@@ -496,9 +504,9 @@ impl ActiveSequencesMultiWorker {
isl,
overlap,
} => {
request_to_worker.insert(event.request_id.clone(), event.worker_id);
request_to_worker.insert(event.request_id.clone(), event.worker);
if let Some(sender) = senders.get(&event.worker_id) {
if let Some(sender) = senders.get(&event.worker) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel();
let _ = sender.send(UpdateSequences::AddRequest {
......@@ -510,14 +518,14 @@ impl ActiveSequencesMultiWorker {
});
} else {
tracing::warn!(
"Worker {} not found, cannot process AddRequest",
event.worker_id
"Worker {:?} not found, cannot process AddRequest",
event.worker
);
}
}
ActiveSequenceEventData::Free => {
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker_id)
if let Some((_, worker)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker)
{
let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(),
......@@ -525,8 +533,8 @@ impl ActiveSequencesMultiWorker {
}
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker_id) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker_id)
if let Some(worker) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker)
{
let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(),
......@@ -547,41 +555,53 @@ impl ActiveSequencesMultiWorker {
}
/// Update the set of workers, adding and removing as needed
pub fn update_workers(&self, new_worker_ids: Vec<WorkerId>) {
let current_workers: HashSet<WorkerId> =
pub fn update_workers(
&self,
new_workers_with_configs: HashMap<i64, Option<ModelRuntimeConfig>>,
) {
let current_workers: HashSet<WorkerWithDpRank> =
self.senders.iter().map(|entry| *entry.key()).collect();
let new_workers: HashSet<WorkerId> = new_worker_ids.into_iter().collect();
let workers_to_remove: Vec<WorkerId> =
// Expand new workers by their dp_rank
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1);
for dp_rank in 0..dp_size {
new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
}
}
let workers_to_remove: Vec<WorkerWithDpRank> =
current_workers.difference(&new_workers).copied().collect();
let workers_to_add: Vec<WorkerId> =
let workers_to_add: Vec<WorkerWithDpRank> =
new_workers.difference(&current_workers).copied().collect();
// Remove workers
for worker_id in &workers_to_remove {
tracing::warn!("Removing worker {}", worker_id);
// Remove workers (this will naturally remove all dp ranks for a worker_id)
for worker in &workers_to_remove {
tracing::warn!("Removing worker {:?}", worker);
// Send shutdown command to the worker
if let Some((_, sender)) = self.senders.remove(worker_id) {
if let Some((_, sender)) = self.senders.remove(worker) {
let _ = sender.send(UpdateSequences::Shutdown);
}
self.handles.remove(worker_id);
self.handles.remove(worker);
// Clean up request_to_worker mappings for this worker
self.request_to_worker
.retain(|_request_id, mapped_worker_id| *mapped_worker_id != *worker_id);
.retain(|_request_id, mapped_worker| mapped_worker != worker);
}
// Add new workers
for worker_id in &workers_to_add {
tracing::warn!("Adding worker {}", worker_id);
for worker in &workers_to_add {
tracing::warn!("Adding worker {:?}", worker);
let (sender, handle) = Self::start_worker(
self.block_size,
self.component.drt().runtime().child_token(),
);
self.senders.insert(*worker_id, sender);
self.handles.insert(*worker_id, handle);
self.senders.insert(*worker, sender);
self.handles.insert(*worker, handle);
}
}
......@@ -591,10 +611,10 @@ impl ActiveSequencesMultiWorker {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
worker_id: WorkerId,
worker: WorkerWithDpRank,
) -> Result<()> {
if !self.senders.contains_key(&worker_id) {
return Err(anyhow::anyhow!("Worker ID {worker_id} not found"));
if !self.senders.contains_key(&worker) {
return Err(anyhow::anyhow!("Worker {:?} not found", worker));
}
// Create response channel
......@@ -604,7 +624,7 @@ impl ActiveSequencesMultiWorker {
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker_id,
worker,
data: ActiveSequenceEventData::AddRequest {
token_sequence: token_sequence.clone(),
isl,
......@@ -617,11 +637,11 @@ impl ActiveSequencesMultiWorker {
.await?;
}
// Update local state
self.request_to_worker.insert(request_id.clone(), worker_id);
// Update local state with full WorkerWithDpRank
self.request_to_worker.insert(request_id.clone(), worker);
self.senders
.get(&worker_id)
.get(&worker)
.unwrap()
.send(UpdateSequences::AddRequest {
request_id,
......@@ -646,7 +666,7 @@ impl ActiveSequencesMultiWorker {
}
pub async fn free(&self, request_id: &RequestId) -> Result<()> {
let worker_id = self
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
......@@ -656,7 +676,7 @@ impl ActiveSequencesMultiWorker {
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker_id,
worker,
data: ActiveSequenceEventData::Free,
router_id: self.router_id,
};
......@@ -667,7 +687,7 @@ impl ActiveSequencesMultiWorker {
// Update local state
self.senders
.get(&worker_id)
.get(&worker)
.unwrap()
.send(UpdateSequences::Free {
request_id: request_id.clone(),
......@@ -681,7 +701,7 @@ impl ActiveSequencesMultiWorker {
/// Mark prefill as completed for a request
pub async fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<()> {
let worker_id = self
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
......@@ -691,7 +711,7 @@ impl ActiveSequencesMultiWorker {
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker_id,
worker,
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id,
};
......@@ -702,7 +722,7 @@ impl ActiveSequencesMultiWorker {
// Update local state
self.senders
.get(&worker_id)
.get(&worker)
.unwrap()
.send(UpdateSequences::MarkPrefillCompleted {
request_id: request_id.clone(),
......@@ -727,33 +747,33 @@ impl ActiveSequencesMultiWorker {
Option<Arc<Vec<SequenceHash>>>,
tokio::sync::oneshot::Sender<T>,
) -> UpdateSequences,
) -> HashMap<WorkerId, T> {
) -> HashMap<WorkerWithDpRank, T> {
let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Send queries to all workers in parallel
for entry in self.senders.iter() {
let worker_id = *entry.key();
let worker = *entry.key();
let sender = entry.value();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker_id, resp_rx));
receivers.push((worker, resp_rx));
if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) {
tracing::error!("Failed to send command to worker {}: {}", worker_id, e);
tracing::error!("Failed to send command to worker {:?}: {}", worker, e);
}
}
// Collect results from all workers
for (worker_id, receiver) in receivers {
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok(result)) => {
results.insert(worker_id, result);
results.insert(worker, result);
}
Ok(Err(_)) => {
tracing::error!("Worker {} dropped response channel", worker_id);
tracing::error!("Worker {:?} dropped response channel", worker);
}
Err(_) => {
tracing::error!("Timeout waiting for response from worker {}", worker_id);
tracing::error!("Timeout waiting for response from worker {:?}", worker);
}
}
}
......@@ -762,7 +782,10 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for the number of new blocks that would be added by a token sequence
pub async fn new_blocks(&self, token_sequence: Vec<SequenceHash>) -> HashMap<WorkerId, usize> {
pub async fn new_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::NewBlocks {
token_sequence: ts,
......@@ -777,7 +800,7 @@ impl ActiveSequencesMultiWorker {
pub async fn potential_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerId, usize> {
) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::PotentialBlocks {
token_sequence: ts,
......@@ -794,45 +817,49 @@ impl ActiveSequencesMultiWorker {
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlaps: OverlapScores,
) -> (HashMap<WorkerId, usize>, HashMap<WorkerId, usize>) {
) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Send queries to all workers in parallel
for entry in self.senders.iter() {
let worker_id = *entry.key();
let sender = entry.value();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker_id, resp_rx));
// Iterate through overlaps to process each WorkerWithDpRank
for (worker, overlap) in overlaps.scores.iter() {
// Check if the worker has a sender
if let Some(sender) = self.senders.get(worker) {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((*worker, resp_rx));
if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens {
token_sequence: token_sequence_shared.clone(),
isl,
overlap: overlaps.scores.get(&worker_id).copied().unwrap_or(0),
resp_tx,
}) {
tracing::error!(
"Failed to send potential_tokens command to worker {}: {}",
worker_id,
e
);
if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens {
token_sequence: token_sequence_shared.clone(),
isl,
overlap: *overlap,
resp_tx,
}) {
tracing::error!(
"Failed to send potential_tokens command to worker {:?}: {}",
worker,
e
);
}
}
}
// Collect results from all workers
for (worker_id, receiver) in receivers {
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok((blocks, tokens))) => {
potential_blocks.insert(worker_id, blocks);
potential_tokens.insert(worker_id, tokens);
potential_blocks.insert(worker, blocks);
potential_tokens.insert(worker, tokens);
}
Ok(Err(_)) => {
tracing::error!("Worker {} dropped response channel", worker_id);
tracing::error!("Worker {:?} dropped response channel", worker);
}
Err(_) => {
tracing::error!("Timeout waiting for response from worker {}", worker_id);
tracing::error!("Timeout waiting for response from worker {:?}", worker);
}
}
}
......@@ -841,13 +868,13 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for their current number of active blocks
pub async fn active_blocks(&self) -> HashMap<WorkerId, usize> {
pub async fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
.await
}
/// Query all workers for their current number of active tokens
pub async fn active_tokens(&self) -> HashMap<WorkerId, usize> {
pub async fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
.await
}
......@@ -918,20 +945,33 @@ mod tests {
.create()
.await?;
// Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works
let worker_ids = vec![0, 1, 2];
// Create multi-worker sequence managers with:
// - Worker 0 with dp_size=2 (dp_ranks 0 and 1)
// - Worker 1 with dp_size=1 (dp_rank 0)
// This gives us 3 effective workers total to test dp_rank effect
// Both seq_managers use the same component to ensure event synchronization works
let mut workers_with_configs = HashMap::new();
// Create runtime config for worker 0 with dp_size=2
let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
config_worker_0.data_parallel_size = 2;
workers_with_configs.insert(0, Some(config_worker_0));
// Create runtime config for worker 1 with dp_size=1 (default)
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, Some(config_worker_1));
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
worker_ids.clone(),
workers_with_configs.clone(),
true,
Uuid::new_v4().to_string(),
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
worker_ids,
workers_with_configs,
true,
Uuid::new_v4().to_string(),
));
......@@ -941,36 +981,36 @@ mod tests {
// PHASE 1: Add requests using both seq_manager_1 and seq_manager_2
// Add request_0 to worker 0: sequence [0, 1, 2]
// Add request_0 to worker 0, dp_rank 0: sequence [0, 1, 2]
seq_manager_1
.add_request(
"request_0".to_string(),
Some(vec![0, 1, 2]),
12, // ISL (3 blocks * 4 block_size)
0, // no overlap
0, // worker_id
WorkerWithDpRank::new(0, 0),
)
.await?;
// Add request_1 to worker 1: sequence [3, 4]
// Add request_1 to worker 0, dp_rank 1: sequence [3, 4]
seq_manager_1
.add_request(
"request_1".to_string(),
Some(vec![3, 4]),
8, // ISL (2 blocks * 4 block_size)
0, // no overlap
1, // worker_id
WorkerWithDpRank::new(0, 1),
)
.await?;
// Add request_2 to worker 2: sequence [0, 1, 2, 3] using seq_manager_2
// Add request_2 to worker 1, dp_rank 0: sequence [0, 1, 2, 3] using seq_manager_2
seq_manager_2
.add_request(
"request_2".to_string(),
Some(vec![0, 1, 2, 3]),
16, // ISL (4 blocks * 4 block_size)
0, // no overlap
2, // worker_id
WorkerWithDpRank::new(1, 0),
)
.await?;
......@@ -981,27 +1021,38 @@ mod tests {
let blocks_phase1 = seq_manager_1.active_blocks().await;
let tokens_phase1 = seq_manager_1.active_tokens().await;
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
// Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// We now have:
// - Worker 0, dp_rank 0: request_0
// - Worker 0, dp_rank 1: request_1
// - Worker 1, dp_rank 0: request_2
let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
let worker_1_dp0 = WorkerWithDpRank::new(1, 0);
assert_eq!(
blocks_phase1[&0], 3,
"Worker 0 should have 3 active blocks (from request_0)"
blocks_phase1[&worker_0_dp0], 3,
"Worker 0 dp_rank 0 should have 3 active blocks (from request_0)"
);
assert_eq!(
blocks_phase1[&1], 2,
"Worker 1 should have 2 active blocks (from request_1)"
blocks_phase1[&worker_0_dp1], 2,
"Worker 0 dp_rank 1 should have 2 active blocks (from request_1)"
);
assert_eq!(
blocks_phase1[&2], 4,
"Worker 2 should have 4 active blocks (from request_2 added by seq_manager_2)"
blocks_phase1[&worker_1_dp0], 4,
"Worker 1 dp_rank 0 should have 4 active blocks (from request_2 added by seq_manager_2)"
);
assert_eq!(
tokens_phase1[&0], 12,
"Worker 0 should have 12 active tokens"
tokens_phase1[&worker_0_dp0], 12,
"Worker 0 dp_rank 0 should have 12 active tokens"
);
assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
assert_eq!(
tokens_phase1[&2], 16,
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
tokens_phase1[&worker_0_dp1], 8,
"Worker 0 dp_rank 1 should have 8 active tokens"
);
assert_eq!(
tokens_phase1[&worker_1_dp0], 16,
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
// PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2
......@@ -1020,17 +1071,23 @@ mod tests {
let blocks_phase2 = seq_manager_2.active_blocks().await;
let tokens_phase2 = seq_manager_2.active_tokens().await;
// Verify phase 2 results - everything should be empty
for worker_id in 0..=2 {
// Verify phase 2 results - everything should be empty for all 3 workers
let all_workers = vec![
WorkerWithDpRank::new(0, 0),
WorkerWithDpRank::new(0, 1),
WorkerWithDpRank::new(1, 0),
];
for worker in all_workers {
assert_eq!(
blocks_phase2[&worker_id], 0,
"Worker {} should have 0 active blocks after all requests freed",
worker_id
blocks_phase2[&worker], 0,
"Worker (id={}, dp_rank={}) should have 0 active blocks after all requests freed",
worker.worker_id, worker.dp_rank
);
assert_eq!(
tokens_phase2[&worker_id], 0,
"Worker {} should have 0 active tokens after all requests freed",
worker_id
tokens_phase2[&worker], 0,
"Worker (id={}, dp_rank={}) should have 0 active tokens after all requests freed",
worker.worker_id, worker.dp_rank
);
}
......@@ -1059,18 +1116,22 @@ mod tests {
// Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works
let worker_ids = vec![0, 1, 2];
let mut workers_with_configs = HashMap::new();
workers_with_configs.insert(0, None);
workers_with_configs.insert(1, None);
workers_with_configs.insert(2, None);
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
worker_ids.clone(),
workers_with_configs.clone(),
true,
Uuid::new_v4().to_string(),
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
worker_ids,
workers_with_configs,
true,
Uuid::new_v4().to_string(),
));
......@@ -1087,7 +1148,7 @@ mod tests {
None, // No token sequence
12, // ISL (12 tokens)
0, // no overlap
0, // worker_id
WorkerWithDpRank::from_worker_id(0),
)
.await?;
......@@ -1098,7 +1159,7 @@ mod tests {
None, // No token sequence
8, // ISL (8 tokens)
0, // no overlap
1, // worker_id
WorkerWithDpRank::from_worker_id(1),
)
.await?;
......@@ -1109,7 +1170,7 @@ mod tests {
None, // No token sequence
16, // ISL (16 tokens)
0, // no overlap
2, // worker_id
WorkerWithDpRank::from_worker_id(2),
)
.await?;
......@@ -1120,13 +1181,20 @@ mod tests {
let tokens_phase1 = seq_manager_1.active_tokens().await;
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
let worker_0 = WorkerWithDpRank::from_worker_id(0);
let worker_1 = WorkerWithDpRank::from_worker_id(1);
let worker_2 = WorkerWithDpRank::from_worker_id(2);
assert_eq!(
tokens_phase1[&0], 12,
tokens_phase1[&worker_0], 12,
"Worker 0 should have 12 active tokens"
);
assert_eq!(tokens_phase1[&1], 8, "Worker 1 should have 8 active tokens");
assert_eq!(
tokens_phase1[&2], 16,
tokens_phase1[&worker_1], 8,
"Worker 1 should have 8 active tokens"
);
assert_eq!(
tokens_phase1[&worker_2], 16,
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
......@@ -1156,8 +1224,9 @@ mod tests {
// Verify phase 2 results - everything should be empty
for worker_id in 0..=2 {
let worker = WorkerWithDpRank::from_worker_id(worker_id);
assert_eq!(
tokens_phase2[&worker_id], 0,
tokens_phase2[&worker], 0,
"Worker {} should have 0 active tokens after all requests freed",
worker_id
);
......
......@@ -23,7 +23,8 @@ use crate::{
kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerId},
indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
},
};
......
......@@ -211,11 +211,26 @@ impl LocalModelBuilder {
.map(RequestTemplate::load)
.transpose()?;
// Override runtime configs with mocker engine args (applies to both paths)
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.kv_cache_block_size = mocker_engine_args.block_size as u32;
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
}
// frontend and echo engine don't need a path.
if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
);
card.kv_cache_block_size = self.kv_cache_block_size;
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone();
......@@ -266,18 +281,6 @@ impl LocalModelBuilder {
card.context_length = context_length;
}
// Override runtime configs with mocker engine args
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();
card.runtime_config = self.runtime_config.clone();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment