Unverified Commit 149bf5a2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(router): add best_worker indexer update flag (#7603)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent daa4fb04
...@@ -123,6 +123,7 @@ config = {"overlap_score_weight": 2.0} if len(token_ids) > 8192 else {} ...@@ -123,6 +123,7 @@ config = {"overlap_score_weight": 2.0} if len(token_ids) > 8192 else {}
worker_id, dp_rank, overlap = await router.best_worker( worker_id, dp_rank, overlap = await router.best_worker(
token_ids, token_ids,
request_id="req-123", request_id="req-123",
update_indexer=True,
router_config_override=config router_config_override=config
) )
......
...@@ -19,9 +19,10 @@ The `KvRouter` provides the following methods: ...@@ -19,9 +19,10 @@ The `KvRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management. - **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`. - **`best_worker(token_ids, router_config_override=None, request_id=None, update_indexer=False)`**: Query which worker would be selected for given tokens. Returns `(worker_id, dp_rank, overlap_blocks)`.
- Without `request_id`: Query-only, doesn't update router state - Without `request_id`: Query-only, doesn't update router state
- With `request_id`: Updates router state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking - With `request_id`: Updates router lifecycle state to track the request. **Note**: If used with `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to maintain accurate load tracking
- With `update_indexer=True`: Records the selected worker in the approximate indexer for future overlap predictions. This is only meaningful when `use_kv_events=False`
- **`get_potential_loads(token_ids)`**: Get detailed load information for all workers, including potential prefill tokens and active decode blocks. Returns a list of load dictionaries. - **`get_potential_loads(token_ids)`**: Get detailed load information for all workers, including potential prefill tokens and active decode blocks. Returns a list of load dictionaries.
...@@ -165,7 +166,11 @@ stream = await router.generate(token_ids=tokens, model="model-name") ...@@ -165,7 +166,11 @@ stream = await router.generate(token_ids=tokens, model="model-name")
### 2. Manual State Management (Advanced) ### 2. Manual State Management (Advanced)
Use `best_worker(request_id=...)` to select and track, then manage the request yourself: Use `best_worker(request_id=...)` to select and track, then manage the request yourself:
```python ```python
worker_id, _dp_rank, overlap = await router.best_worker(tokens, request_id="req-123") worker_id, _dp_rank, overlap = await router.best_worker(
tokens,
request_id="req-123",
update_indexer=True, # needed for approximate mode (use_kv_events=False)
)
response = await client.generate(tokens, request_id="req-123") response = await client.generate(tokens, request_id="req-123")
# await anext(response) # Get first token # await anext(response) # Get first token
await router.mark_prefill_complete("req-123") # After first token await router.mark_prefill_complete("req-123") # After first token
...@@ -175,6 +180,7 @@ await router.free("req-123") # After completion ...@@ -175,6 +180,7 @@ await router.free("req-123") # After completion
``` ```
- **Best for**: Custom request handling with router state tracking - **Best for**: Custom request handling with router state tracking
- **Requires**: Calling `mark_prefill_complete()` and `free()` at correct lifecycle points - **Requires**: Calling `mark_prefill_complete()` and `free()` at correct lifecycle points
- **Approximate mode**: Pass `update_indexer=True` when `use_kv_events=False` so the router learns from manual worker selections
- **Caution**: Incorrect lifecycle management degrades load balancing accuracy - **Caution**: Incorrect lifecycle management degrades load balancing accuracy
### 3. Hierarchical Router Probing ### 3. Hierarchical Router Probing
......
...@@ -1048,13 +1048,15 @@ impl KvRouter { ...@@ -1048,13 +1048,15 @@ impl KvRouter {
Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker)) Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
} }
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, update_indexer=false, block_mm_infos=None, lora_name=None))]
fn best_worker<'p>( fn best_worker<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
request_id: Option<String>, request_id: Option<String>,
update_indexer: bool,
block_mm_infos: Option<PyObject>, block_mm_infos: Option<PyObject>,
lora_name: Option<String>, lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
...@@ -1089,6 +1091,13 @@ impl KvRouter { ...@@ -1089,6 +1091,13 @@ impl KvRouter {
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
if update_indexer && !chooser.kv_router_config().use_kv_events {
chooser
.record_routing_decision(token_ids.clone(), best_worker)
.await
.map_err(to_pyerr)?;
}
Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks)) Ok((best_worker.worker_id, best_worker.dp_rank, overlap_blocks))
}) })
} }
......
...@@ -1707,6 +1707,7 @@ class KvRouter: ...@@ -1707,6 +1707,7 @@ class KvRouter:
token_ids: List[int], token_ids: List[int],
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None, request_id: Optional[str] = None,
update_indexer: bool = False,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None, block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None, lora_name: Optional[str] = None,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
...@@ -1719,6 +1720,10 @@ class KvRouter: ...@@ -1719,6 +1720,10 @@ class KvRouter:
request_id: Optional request ID. If provided, router states will be updated request_id: Optional request ID. If provided, router states will be updated
to track this request (active blocks, lifecycle events). If not to track this request (active blocks, lifecycle events). If not
provided, this is a query-only operation that doesn't affect state. provided, this is a query-only operation that doesn't affect state.
update_indexer: Whether to record the selected worker in the router's
approximate indexer. This is only meaningful when
`use_kv_events=False` and is independent from lifecycle
state tracking via `request_id`.
block_mm_infos: Optional block-level multimodal metadata aligned to request block_mm_infos: Optional block-level multimodal metadata aligned to request
blocks. When provided, this is used in block hash computation blocks. When provided, this is used in block hash computation
to enable MM-aware worker selection. to enable MM-aware worker selection.
......
...@@ -386,6 +386,17 @@ where ...@@ -386,6 +386,17 @@ where
&self.kv_router_config &self.kv_router_config
} }
/// Records that `worker` was selected for the request made up of `tokens`.
///
/// The tokens are wrapped in a `TokensWithHashes` built with this router's
/// `block_size` and handed, together with `worker`, to the indexer's
/// `process_routing_decision_for_request`.
///
/// # Errors
///
/// Returns whatever `KvRouterError` the indexer call reports.
pub async fn record_routing_decision(
    &self,
    tokens: Vec<u32>,
    worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
    let indexer = &self.indexer;
    let mut hashed_tokens = TokensWithHashes::new(tokens, self.block_size);
    indexer
        .process_routing_decision_for_request(&mut hashed_tokens, worker)
        .await
}
/// Given these tokens, find the worker with the best match in its KV cache. /// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks. /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking. /// Now also takes optional context_id for request tracking.
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::protocols::{TokensWithHashes, WorkerWithDpRank}; use dynamo_kv_router::protocols::WorkerWithDpRank;
use dynamo_runtime::{ use dynamo_runtime::{
dynamo_nvtx_range, dynamo_nvtx_range,
pipeline::{ pipeline::{
...@@ -386,12 +386,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -386,12 +386,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// This covers both pre-selected workers and find_best_match selections. // This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events { if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
let worker = WorkerWithDpRank::new(instance_id, dp_rank); let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size());
if let Err(e) = self if let Err(e) = self
.chooser .chooser
.indexer() .record_routing_decision(request.token_ids.clone(), worker)
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await .await
{ {
tracing::warn!( tracing::warn!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment