Commit c3fcfdd6 (unverified), authored by Yan Ru Pei, committed by GitHub

chore: consolidations of KvPushRouter bindings and usage examples (#3543)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
Signed-off-by: Yan Ru Pei <yanrpei@gmail.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 65cc5337
@@ -316,6 +316,24 @@ To manage stream growth, when the message count exceeds `--router-snapshot-thres
Instead of launching the KV Router via command line, you can create a `KvPushRouter` object directly in Python. This allows per-request routing configuration overrides.
### Methods
The `KvPushRouter` provides the following methods:
- **`generate(token_ids, model, ...)`**: Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.
- **`best_worker_id(token_ids, router_config_override=None, request_id=None)`**: Query which worker would be selected for given tokens. Returns `(worker_id, overlap_blocks)`.
  - Without `request_id`: Query-only; does not update router state.
  - With `request_id`: Updates router state to track the request. **Note**: When you pass `request_id`, you must call `mark_prefill_complete()` and `free()` at the appropriate lifecycle points to keep load tracking accurate.
- **`get_potential_loads(token_ids)`**: Get detailed load information for all workers, including potential prefill tokens and active decode blocks. Returns a list of load dictionaries.
- **`mark_prefill_complete(request_id)`**: Signal that a request has completed its prefill phase. Only used for [manual lifecycle management](#2-manual-state-management-advanced) when using `best_worker_id()` for manual routing instead of `generate()`.
- **`free(request_id)`**: Signal that a request has completed and its resources should be released. Only used for [manual lifecycle management](#2-manual-state-management-advanced) when using `best_worker_id()` for manual routing instead of `generate()`.
- **`dump_events()`**: Dump all KV cache events from the router's indexer as a JSON string. Useful for debugging and analysis.
### Setup
First, launch your backend engines:
@@ -377,15 +395,60 @@ if __name__ == "__main__":
    asyncio.run(main())
```
### Routing Patterns
The `KvPushRouter` supports multiple usage patterns depending on your control requirements:
#### 1. Automatic Routing (Recommended)
Call `generate()` directly and let the router handle everything:
```python
stream = await router.generate(token_ids=tokens, model="model-name")
```
- **Best for**: Most use cases
- **Router automatically**: Selects best worker, updates state, routes request, tracks lifecycle
#### 2. Manual State Management (Advanced)
Use `best_worker_id(request_id=...)` to select and track, then manage the request yourself:
```python
worker_id, overlap = await router.best_worker_id(tokens, request_id="req-123")
response = await client.generate(tokens, request_id="req-123")
# await anext(response) # Get first token
await router.mark_prefill_complete("req-123") # After first token
# async for _ in response: # Continue generating
# ...
await router.free("req-123") # After completion
```
- **Best for**: Custom request handling with router state tracking
- **Requires**: Calling `mark_prefill_complete()` and `free()` at correct lifecycle points
- **Caution**: Incorrect lifecycle management degrades load balancing accuracy
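One way to keep the lifecycle calls paired is to wrap the manual flow in `try`/`finally`, so that `free()` still runs if the stream fails partway. The sketch below is illustrative only: `StubRouter`, `fake_stream`, and `run_tracked` are hypothetical stand-ins so the example runs without a deployment. Only the method names `best_worker_id()`, `mark_prefill_complete()`, and `free()` come from the documentation above.

```python
import asyncio


class StubRouter:
    """Hypothetical stand-in for KvPushRouter that only records lifecycle calls."""

    def __init__(self):
        self.calls = []

    async def best_worker_id(self, token_ids, router_config_override=None, request_id=None):
        self.calls.append(("best_worker_id", request_id))
        return 0, 0  # (worker_id, overlap_blocks), dummy values

    async def mark_prefill_complete(self, request_id):
        self.calls.append(("mark_prefill_complete", request_id))

    async def free(self, request_id):
        self.calls.append(("free", request_id))


async def fake_stream(n_tokens, fail_after=None):
    """Hypothetical response stream; optionally raises partway through."""
    for i in range(n_tokens):
        if fail_after is not None and i == fail_after:
            raise RuntimeError("stream interrupted")
        yield i


async def run_tracked(router, token_ids, request_id, make_stream):
    """Select a worker with tracking, then keep router state in sync with the stream."""
    worker_id, _overlap = await router.best_worker_id(token_ids, request_id=request_id)
    outputs = []
    try:
        first = True
        async for item in make_stream(worker_id):
            if first:
                # The first output marks the end of the prefill phase.
                await router.mark_prefill_complete(request_id)
                first = False
            outputs.append(item)
        return outputs
    finally:
        # Always release the request, even when the stream raised.
        await router.free(request_id)
```

In real code, `make_stream` would be whatever call sends the request to the selected worker; the wrapper only guarantees that every tracked request is eventually freed.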
#### 3. Hierarchical Router Probing
Query without state updates, then route through a chosen router:
```python
# Probe multiple routers without updating state
worker_id_1, overlap_1 = await router_1.best_worker_id(tokens) # No request_id
worker_id_2, overlap_2 = await router_2.best_worker_id(tokens)
# Pick the best router based on results, keeping the matching worker ID
if overlap_1 > overlap_2:
    chosen_router, worker_id = router_1, worker_id_1
else:
    chosen_router, worker_id = router_2, worker_id_2
stream = await chosen_router.generate(tokens, model="model-name", worker_id=worker_id)
```
- **Best for**: Multi-tier deployments (e.g., Envoy Gateway routing to multiple router groups)
- **Advantage**: Query multiple routers before committing to one
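With more than two routers, the probes can be issued concurrently. The following is a minimal sketch under stated assumptions: `StubRouter` and `probe_routers` are hypothetical names used so the example runs standalone, and the stub returns fixed values instead of consulting a real index.

```python
import asyncio


class StubRouter:
    """Hypothetical stand-in for KvPushRouter returning a fixed probe result."""

    def __init__(self, worker_id, overlap_blocks):
        self._result = (worker_id, overlap_blocks)

    async def best_worker_id(self, token_ids, router_config_override=None, request_id=None):
        return self._result


async def probe_routers(routers, token_ids):
    """Probe every router without request_id and return (index, worker_id, overlap)."""
    results = await asyncio.gather(*(r.best_worker_id(token_ids) for r in routers))
    # max() keeps the first router on ties, so list order acts as the tie-breaker.
    index = max(range(len(results)), key=lambda i: results[i][1])
    worker_id, overlap = results[index]
    return index, worker_id, overlap
```

Because no `request_id` is passed, none of the probes update router state; only the subsequent `generate()` call on the chosen router does.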
#### 4. Custom Load-Based Routing
Use `get_potential_loads()` to implement custom routing logic:
```python
loads = await router.get_potential_loads(tokens)
# Apply custom logic (e.g., weighted scoring, constraints)
best_worker = min(loads, key=lambda x: custom_cost_fn(x))
stream = await router.generate(tokens, model="model-name", worker_id=best_worker['worker_id'])
```
- **Best for**: Custom optimization strategies beyond the built-in cost function
- **Advantage**: Full control over worker selection logic
- **See also**: Detailed example below in "Custom Routing Example: Minimizing TTFT"
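As a rough illustration of what a `custom_cost_fn` might look like, the sketch below scores workers by a weighted sum and optionally applies a constraint. The dictionary keys `potential_prefill_tokens` and `potential_decode_blocks` are assumptions based on the description of `get_potential_loads()` above (only `worker_id` appears in the snippet), and the weights, the `pick_worker` helper, and the sample data are made up.

```python
def custom_cost_fn(load, prefill_weight=1.0, decode_weight=4.0):
    """Toy cost: weighted sum of pending prefill tokens and active decode blocks."""
    return (
        prefill_weight * load["potential_prefill_tokens"]
        + decode_weight * load["potential_decode_blocks"]
    )


def pick_worker(loads, max_decode_blocks=None):
    """Return the lowest-cost worker_id, optionally enforcing a decode-block cap."""
    candidates = [
        load for load in loads
        if max_decode_blocks is None or load["potential_decode_blocks"] <= max_decode_blocks
    ]
    if not candidates:
        candidates = loads  # Fall back to all workers when the cap excludes everyone.
    return min(candidates, key=custom_cost_fn)["worker_id"]


# Made-up load entries in the assumed shape.
loads = [
    {"worker_id": 1, "potential_prefill_tokens": 100, "potential_decode_blocks": 50},
    {"worker_id": 2, "potential_prefill_tokens": 400, "potential_decode_blocks": 5},
    {"worker_id": 3, "potential_prefill_tokens": 150, "potential_decode_blocks": 20},
]
```

The selected ID would then be passed as `worker_id=` to `generate()`, as in the snippet above.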
All patterns support `router_config_override` to adjust routing behavior per request without recreating the router.
### Custom Routing Example: Minimizing TTFT
......
@@ -1052,12 +1052,13 @@ impl KvPushRouter {
         Self::process_request_to_stream(py, self.inner.clone(), request)
     }
-    #[pyo3(signature = (token_ids, router_config_override=None))]
+    #[pyo3(signature = (token_ids, router_config_override=None, request_id=None))]
     fn best_worker_id<'p>(
         &self,
         py: Python<'p>,
         token_ids: Vec<u32>,
         router_config_override: Option<PyObject>,
+        request_id: Option<String>,
     ) -> PyResult<Bound<'p, PyAny>> {
         let router_config_override = if let Some(obj) = router_config_override {
             Python::with_gil(|py| {
@@ -1069,11 +1070,17 @@ impl KvPushRouter {
             None
         };
-        let inner = self.inner.clone();
+        let chooser = self.inner.chooser.clone();
+        let update_states = request_id.is_some();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let (worker_id, overlap_blocks) = inner
-                .find_best_match(&token_ids, router_config_override.as_ref())
+            let (worker_id, overlap_blocks) = chooser
+                .find_best_match(
+                    request_id.as_deref(),
+                    &token_ids,
+                    router_config_override.as_ref(),
+                    update_states,
+                )
                 .await
                 .map_err(to_pyerr)?;
@@ -1082,15 +1089,42 @@ impl KvPushRouter {
         })
     }
+    /// Mark prefill as completed for a request
+    fn mark_prefill_complete<'p>(
+        &self,
+        py: Python<'p>,
+        request_id: String,
+    ) -> PyResult<Bound<'p, PyAny>> {
+        let chooser = self.inner.chooser.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            chooser
+                .mark_prefill_completed(&request_id)
+                .await
+                .map_err(to_pyerr)?;
+            Ok(())
+        })
+    }
+    /// Free a request by its ID, signaling the router to release resources
+    fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
+        let chooser = self.inner.chooser.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            chooser.free(&request_id).await.map_err(to_pyerr)?;
+            Ok(())
+        })
+    }
     fn get_potential_loads<'p>(
         &self,
         py: Python<'p>,
         token_ids: Vec<u32>,
     ) -> PyResult<Bound<'p, PyAny>> {
-        let inner = self.inner.clone();
+        let chooser = self.inner.chooser.clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let loads = inner
+            let loads = chooser
                 .get_potential_loads(&token_ids)
                 .await
                 .map_err(to_pyerr)?;
@@ -1106,10 +1140,10 @@ impl KvPushRouter {
     /// Dump all events from the KV router's indexer as a JSON string
     fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
-        let inner = self.inner.clone();
+        let chooser = self.inner.chooser.clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let events = inner.dump_events().await.map_err(to_pyerr)?;
+            let events = chooser.dump_events().await.map_err(to_pyerr)?;
             // Serialize to JSON string
             let json_str = serde_json::to_string(&events).map_err(to_pyerr)?;
             Ok(json_str)
......
@@ -1221,13 +1221,17 @@ class KvPushRouter:
         self,
         token_ids: List[int],
         router_config_override: Optional[JsonLike] = None,
+        request_id: Optional[str] = None,
     ) -> Tuple[int, int]:
         """
-        Find the best matching worker for the given tokens without updating states.
+        Find the best matching worker for the given tokens.
         Args:
             token_ids: List of token IDs to find matches for
             router_config_override: Optional router configuration override
+            request_id: Optional request ID. If provided, router states will be updated
+                to track this request (active blocks, lifecycle events). If not
+                provided, this is a query-only operation that doesn't affect state.
         Returns:
             A tuple of (worker_id, overlap_blocks) where:
@@ -1263,6 +1267,40 @@ class KvPushRouter:
         """
         ...
+    async def mark_prefill_complete(self, request_id: str) -> None:
+        """
+        Mark prefill as completed for a request.
+        This signals that the request has finished its prefill phase and is now
+        in the decode phase. Used to update router state for accurate load tracking.
+        Args:
+            request_id: The ID of the request that completed prefill
+        Note:
+            This is typically called automatically by the router when using the
+            `generate()` method. Only call this manually if you're using
+            `best_worker_id()` with `request_id` for custom routing.
+        """
+        ...
+    async def free(self, request_id: str) -> None:
+        """
+        Free a request by its ID, signaling the router to release resources.
+        This should be called when a request completes to update the router's
+        tracking of active blocks and ensure accurate load balancing.
+        Args:
+            request_id: The ID of the request to free
+        Note:
+            This is typically called automatically by the router when using the
+            `generate()` method. Only call this manually if you're using
+            `best_worker_id()` with `request_id` for custom routing.
+        """
+        ...
 class EntrypointArgs:
     """
     Settings to connect an input to a worker and run them.
......
@@ -472,7 +472,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
 pub struct KvPushRouter {
     inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
-    chooser: Arc<KvRouter>,
+    pub chooser: Arc<KvRouter>,
 }
 impl KvPushRouter {
@@ -482,27 +482,6 @@ impl KvPushRouter {
     ) -> Self {
         KvPushRouter { inner, chooser }
     }
-    /// Find the best matching worker for the given tokens without updating states
-    pub async fn find_best_match(
-        &self,
-        tokens: &[u32],
-        router_config_override: Option<&RouterConfigOverride>,
-    ) -> Result<(i64, u32)> {
-        self.chooser
-            .find_best_match(None, tokens, router_config_override, false)
-            .await
-    }
-    /// Get potential prefill and decode loads for all workers
-    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
-        self.chooser.get_potential_loads(tokens).await
-    }
-    /// Dump all events from the KV router's indexer
-    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
-        self.chooser.dump_events().await
-    }
 }
 #[async_trait]
......
@@ -973,6 +973,12 @@ impl KvIndexerInterface for KvIndexer {
     }
 }
+impl Drop for KvIndexer {
+    fn drop(&mut self) {
+        self.shutdown();
+    }
+}
 #[derive(Debug, Clone)]
 pub struct ShardedMatchRequest {
     sequence: Vec<LocalBlockHash>,
@@ -1249,6 +1255,12 @@ impl KvIndexerInterface for KvIndexerSharded {
     }
 }
+impl Drop for KvIndexerSharded {
+    fn drop(&mut self) {
+        self.shutdown();
+    }
+}
 #[cfg(test)]
 mod tests {
......