Unverified Commit 149bf5a2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(router): add best_worker indexer update flag (#7603)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent daa4fb04
......@@ -123,6 +123,7 @@ config = {"overlap_score_weight": 2.0} if len(token_ids) > 8192 else {}
worker_id, dp_rank, overlap = await router.best_worker(
token_ids,
request_id="req-123",
update_indexer=True,
router_config_override=config
)
......
......@@ -19,9 +19,10 @@ The `KvRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- **`best_worker(token_ids, router_config_override=None, request_id=None, update_indexer=False)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- With `request_id`: Updates router lifecycle state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- With `update_indexer=True`: Records the selected worker in the approximate indexer for future overlap predictions. This is only meaningful when `use_kv_events=False`
- **`get_potential_loads(token_ids)`**: Get detailed load information for all workers, including potential prefill tokens and active decode blocks. Returns a list of load dictionaries.
......@@ -165,7 +166,11 @@ stream = await router.generate(token_ids=tokens, model="model-name")
### 2. Manual State Management (Advanced)
Use `best_worker(request_id=...)` to select and track, then manage the request yourself:
```python
worker_id, _dp_rank, overlap = await router.best_worker(tokens, request_id="req-123")
worker_id, _dp_rank, overlap = await router.best_worker(
tokens,
request_id="req-123",
update_indexer=True, # needed for approximate mode (use_kv_events=False)
)
response = await client.generate(tokens, request_id="req-123")
# await anext(response) # Get first token
await router.mark_prefill_complete("req-123") # After first token
......@@ -175,6 +180,7 @@ await router.free("req-123") # After completion
```
- **Best for**: Custom request handling with router state tracking
- **Requires**: Calling `mark_prefill_complete()` and `free()` at correct lifecycle points
- **Approximate mode**: Pass `update_indexer=True` when `use_kv_events=False` so the router learns from manual worker selections
- **Caution**: Incorrect lifecycle management degrades load balancing accuracy
### 3. Hierarchical Router Probing
......
......@@ -1048,13 +1048,15 @@ impl KvRouter {
Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
}
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, update_indexer=false, block_mm_infos=None, lora_name=None))]
fn best_worker<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
router_config_override: Option<PyObject>,
request_id: Option<String>,
update_indexer: bool,
block_mm_infos: Option<PyObject>,
lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
......@@ -1089,6 +1091,13 @@ impl KvRouter {
.await
.map_err(to_pyerr)?;
if update_indexer && !chooser.kv_router_config().use_kv_events {
chooser
.record_routing_decision(token_ids.clone(), best_worker)
.await
.map_err(to_pyerr)?;
}
Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
})
}
......
......@@ -1707,6 +1707,7 @@ class KvRouter:
token_ids: List[int],
router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None,
update_indexer: bool = False,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> Tuple[int, int, int]:
......@@ -1719,6 +1720,10 @@ class KvRouter:
request_id: Optional request ID. If provided, router states will be updated
to track this request (active blocks, lifecycle events). If not
provided, this is a query-only operation that doesn't affect state.
update_indexer: Whether to record the selected worker in the router's
approximate indexer. This is only meaningful when
`use_kv_events=False` and is independent from lifecycle
state tracking via `request_id`.
block_mm_infos: Optional block-level multimodal metadata aligned to request
blocks. When provided, this is used in block hash computation
to enable MM-aware worker selection.
......
......@@ -386,6 +386,17 @@ where
&self.kv_router_config
}
/// Records a routing decision for `tokens` against `worker` in the indexer.
///
/// The token ids are wrapped in a `TokensWithHashes` built with this router's
/// `block_size`, then handed to the indexer's
/// `process_routing_decision_for_request` together with the chosen worker.
///
/// This method does not inspect `use_kv_events`; deciding whether a decision
/// should be recorded at all is left to the caller.
///
/// # Errors
///
/// Returns whatever `KvRouterError` the indexer call produces.
pub async fn record_routing_decision(
    &self,
    tokens: Vec<u32>,
    worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
    let indexer = &self.indexer;
    // Mutable because the indexer call takes the wrapper by `&mut`.
    let mut hashed_request = TokensWithHashes::new(tokens, self.block_size);
    indexer
        .process_routing_decision_for_request(&mut hashed_request, worker)
        .await
}
/// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking.
......
......@@ -4,7 +4,7 @@
use std::sync::Arc;
use anyhow::Result;
use dynamo_kv_router::protocols::{TokensWithHashes, WorkerWithDpRank};
use dynamo_kv_router::protocols::WorkerWithDpRank;
use dynamo_runtime::{
dynamo_nvtx_range,
pipeline::{
......@@ -386,12 +386,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size());
if let Err(e) = self
.chooser
.indexer()
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.record_routing_decision(request.token_ids.clone(), worker)
.await
{
tracing::warn!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment