Unverified Commit 383e3b3a authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: don't modify kv scheduler states on query + more python binding (#2798)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 7da510cf
......@@ -292,4 +292,70 @@ if __name__ == "__main__":
asyncio.run(main())
```
### Additional Routing Features
The `KvPushRouter` provides additional methods for fine-grained control:
- **`best_worker_id()`**: Query which worker would be selected for given tokens without actually routing the request. Returns `(worker_id, overlap_blocks)`.
- **`get_potential_loads()`**: Get detailed load information for all workers including potential prefill tokens and active decode blocks.
- **`worker_id` parameter in `generate()`**: Force routing to a specific worker by passing `worker_id=<id>` to bypass the automatic KV-aware selection.
The `router_config_override` parameter allows you to adjust routing behavior per request without recreating the router. This is useful for implementing different routing strategies based on request characteristics.
### Custom Routing Example: Minimizing TTFT
Here's an example of using `get_potential_loads()` to implement custom routing that minimizes Time To First Token (TTFT) by selecting the worker with the least prefill work:
```python
import asyncio
from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig
async def minimize_ttft_routing():
# Setup router
runtime = DistributedRuntime.detached()
namespace = runtime.namespace("inference")
component = namespace.component("vllm")
endpoint = component.endpoint("generate")
router = KvPushRouter(
endpoint=endpoint,
block_size=16,
kv_router_config=KvRouterConfig()
)
# Your input tokens
token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Get potential loads for all workers
potential_loads = await router.get_potential_loads(token_ids)
# Find worker with minimum prefill tokens (best for TTFT)
best_worker = min(potential_loads, key=lambda x: x['potential_prefill_tokens'])
print(f"Worker loads: {potential_loads}")
print(f"Selected worker {best_worker['worker_id']} with {best_worker['potential_prefill_tokens']} prefill tokens")
# Route directly to the selected worker
stream = await router.generate(
token_ids=token_ids,
model="meta-llama/Llama-2-7b-hf",
worker_id=best_worker['worker_id'], # Force routing to optimal worker
stop_conditions={"max_tokens": 20}
)
# Process response
async for response in stream:
if isinstance(response, dict) and "token_ids" in response:
print(f"Generated tokens: {response['token_ids']}")
if __name__ == "__main__":
asyncio.run(minimize_ttft_routing())
```
This approach gives you complete control over routing decisions, allowing you to optimize for different metrics based on your specific requirements. As some examples:
- **Minimize TTFT**: Select worker with lowest `potential_prefill_tokens`
- **Maximize cache reuse**: Use `best_worker_id()` which considers both prefill and decode loads
- **Balance load**: Consider both `potential_prefill_tokens` and `potential_decode_blocks` together
See [KV Router Architecture](../components/router/README.md) for performance tuning details.
......@@ -909,7 +909,7 @@ impl KvPushRouter {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None))]
#[pyo3(signature = (token_ids, model, stop_conditions=None, sampling_options=None, output_options=None, router_config_override=None, worker_id=None))]
fn generate<'p>(
&self,
py: Python<'p>,
......@@ -919,6 +919,7 @@ impl KvPushRouter {
sampling_options: Option<PyObject>,
output_options: Option<PyObject>,
router_config_override: Option<PyObject>,
worker_id: Option<i64>,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the options with defaults
let (stop_conditions, sampling_options, output_options, router_config_override) =
......@@ -957,15 +958,22 @@ impl KvPushRouter {
})?;
// Build the PreprocessedRequest
let request = llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder()
let mut request_builder =
llm_rs::protocols::common::preprocessor::PreprocessedRequest::builder();
request_builder
.model(model)
.token_ids(token_ids)
.stop_conditions(stop_conditions)
.sampling_options(sampling_options)
.output_options(output_options)
.router_config_override(router_config_override)
.build()
.map_err(to_pyerr)?;
.router_config_override(router_config_override);
// Set backend_instance_id if worker_id is provided
if let Some(worker_id) = worker_id {
request_builder.backend_instance_id(Some(worker_id));
}
let request = request_builder.build().map_err(to_pyerr)?;
let inner = self.inner.clone();
......@@ -1010,6 +1018,59 @@ impl KvPushRouter {
})
}
#[pyo3(signature = (context_id, token_ids, router_config_override=None))]
fn best_worker_id<'p>(
&self,
py: Python<'p>,
context_id: String,
token_ids: Vec<u32>,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
depythonize(obj.bind(py)).map_err(to_pyerr)?;
Ok::<_, PyErr>(Some(override_config))
})?
} else {
None
};
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner
.find_best_match(&context_id, &token_ids, router_config_override.as_ref())
.await
.map_err(to_pyerr)?;
// Return a tuple of (worker_id, overlap_blocks)
Ok((worker_id, overlap_blocks))
})
}
fn get_potential_loads<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = inner
.get_potential_loads(&token_ids)
.await
.map_err(to_pyerr)?;
// Use pythonize to convert Vec<PotentialLoad> to Python list of dicts
Python::with_gil(|py| {
pythonize(py, &loads)
.map(|obj| obj.unbind())
.map_err(to_pyerr)
})
})
}
/// Dump all events from the KV router's indexer as a JSON string
fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
......
......@@ -1227,6 +1227,7 @@ class KvPushRouter:
sampling_options: Optional[JsonLike] = None,
output_options: Optional[JsonLike] = None,
router_config_override: Optional[JsonLike] = None,
worker_id: Optional[int] = None,
) -> AsyncIterator[JsonLike]:
"""
Generate text using the KV-aware router.
......@@ -1238,9 +1239,56 @@ class KvPushRouter:
sampling_options: Optional sampling configuration
output_options: Optional output configuration
router_config_override: Optional router configuration override
worker_id: Optional worker ID to route to directly. If set, the request
will be sent to this specific worker and router states will be
updated accordingly.
Returns:
An async iterator yielding generation responses
Note:
- If worker_id is set, the request bypasses KV matching and routes directly
to the specified worker while still updating router states.
- This is different from query_instance_id which doesn't route the request.
"""
...
async def best_worker_id(
self,
context_id: str,
token_ids: List[int],
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
"""
Find the best matching worker for the given tokens without updating states.
Args:
context_id: String identifier for the request
token_ids: List of token IDs to find matches for
router_config_override: Optional router configuration override
Returns:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def get_potential_loads(
self,
token_ids: List[int],
) -> List[Dict[str, int]]:
"""
Get potential prefill and decode loads for all workers.
Args:
token_ids: List of token IDs to evaluate
Returns:
A list of dictionaries, each containing:
- worker_id: The worker ID
- potential_prefill_tokens: Number of tokens that would need prefill
- potential_decode_blocks: Number of blocks currently in decode phase
"""
...
......
......@@ -41,7 +41,7 @@ use crate::{
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scoring::ProcessedEndpoints,
subscriber::start_kv_router_background,
},
......@@ -287,6 +287,7 @@ impl KvRouter {
context_id: &str,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> anyhow::Result<(i64, u32)> {
let isl_tokens = tokens.len();
......@@ -303,6 +304,7 @@ impl KvRouter {
seq_hashes.clone(),
overlap_scores.clone(),
router_config_override,
update_states,
)
.await?;
......@@ -321,6 +323,28 @@ impl KvRouter {
Ok((best_worker_id, overlap_amount))
}
pub async fn add_request(
&self,
request_id: String,
tokens: &[u32],
overlap_blocks: u32,
worker_id: i64,
) {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
self.scheduler
.add_request(
request_id,
seq_hashes,
isl_tokens,
overlap_blocks,
worker_id,
)
.await;
}
pub async fn mark_prefill_completed(&self, request_id: &str) {
self.scheduler.mark_prefill_completed(request_id).await
}
......@@ -333,6 +357,19 @@ impl KvRouter {
self.block_size
}
/// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(self
.scheduler
.get_potential_loads(seq_hashes, isl_tokens, overlap_scores)
.await)
}
/// Dump all events from the indexer
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
......@@ -348,7 +385,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
) -> Result<ManyOut<Annotated<RouterResponse>>> {
let (request, ctx) = request.into_parts();
let (worker_id, _) = self
.find_best_match(ctx.id(), &request.tokens, None)
.find_best_match(ctx.id(), &request.tokens, None, true)
.await?;
let response = RouterResponse { worker_id };
......@@ -371,6 +408,23 @@ impl KvPushRouter {
KvPushRouter { inner, chooser }
}
/// Find the best matching worker for the given tokens without updating states
pub async fn find_best_match(
&self,
context_id: &str,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
) -> Result<(i64, u32)> {
self.chooser
.find_best_match(context_id, tokens, router_config_override, false)
.await
}
/// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
self.chooser.get_potential_loads(tokens).await
}
/// Dump all events from the KV router's indexer
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.chooser.dump_events().await
......@@ -381,6 +435,25 @@ impl KvPushRouter {
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for KvPushRouter
{
/// Generate method that handles KV-aware routing with three distinct behaviors:
///
/// 1. **If `query_instance_id` annotation is set**:
/// - Returns the best matching worker ID without routing the request
/// - Does NOT update any router local states
/// - Response includes worker_instance_id and token_data annotations
///
/// 2. **If `backend_instance_id` is set in the request**:
/// - Routes directly to the specified backend instance
/// - DOES update router states to track this request (unless query_instance_id is also set)
/// - Bypasses the normal KV matching logic
///
/// 3. **If neither are set (default behavior)**:
/// - Finds the best worker based on KV cache overlap
/// - Updates router states to track the request
/// - Routes to the selected worker
///
/// The router state updates include tracking active sequences and managing
/// prefill/completion lifecycle for proper KV cache management.
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
......@@ -390,8 +463,17 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
InstanceSource::Dynamic(_) => {
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it
// If instance_id is set, use it and manually add the request to track it
if !query_instance_id {
self.chooser
.add_request(context_id.clone(), &request.token_ids, 0, id)
.await;
}
(id, 0)
} else {
// Otherwise, find the best match
......@@ -400,17 +482,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
&context_id,
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
)
.await?
};
let query_instance_id = request.has_annotation("query_instance_id");
// Extract context information before moving the request
// if request has the annotation "query_instance_id",
// then the request will not be routed to the worker,
// and instead the worker_instance_id will be returned.
let stream_context = request.context().clone();
// if request has the annotation "query_instance_id", for example
// curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
// request will not be routed to worker immediately.
// The gateway EPP will receive the worker_instance_id and the tokens.
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response =
......@@ -426,7 +506,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream = stream::iter(vec![response, response_tokens]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
// Update the request with the estimated prefix hit blocks
let (mut backend_input, context) = request.into_parts();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
......
......@@ -6,7 +6,7 @@ use dynamo_runtime::component::{Component, Instance};
use dynamo_runtime::traits::events::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
......@@ -28,6 +28,13 @@ pub struct KVHitRateEvent {
pub overlap_blocks: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
pub worker_id: i64,
pub potential_prefill_tokens: usize,
pub potential_decode_blocks: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")]
......@@ -55,6 +62,8 @@ pub struct SchedulingRequest {
pub prefill_tokens: HashMap<i64, usize>,
// Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests)
pub update_states: bool,
// Option to take it out to send the response without moving the struct
resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
}
......@@ -204,15 +213,18 @@ impl KvScheduler {
};
request.respond(response);
let _ = slots_clone
.add_request(
request.request_id,
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await;
// Only update the state if update_states is true
if request.update_states {
let _ = slots_clone
.add_request(
request.request_id,
request.token_seq,
request.isl_tokens,
selection.overlap_blocks,
selection.worker_id,
)
.await;
}
continue;
}
......@@ -247,6 +259,7 @@ impl KvScheduler {
token_seq: Vec<SequenceHash>,
overlaps: OverlapScores,
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
......@@ -257,6 +270,7 @@ impl KvScheduler {
decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(),
update_states,
resp_tx: Some(resp_tx), // Wrap in Some()
};
......@@ -272,6 +286,20 @@ impl KvScheduler {
Ok(best_worker_id)
}
pub async fn add_request(
&self,
request_id: String,
token_sequence: Vec<SequenceHash>,
isl: usize,
overlap: u32,
worker_id: i64,
) {
let _ = self
.slots
.add_request(request_id, token_sequence, isl, overlap, worker_id)
.await;
}
pub async fn mark_prefill_completed(&self, request_id: &str) {
let _ = self
.slots
......@@ -282,6 +310,38 @@ impl KvScheduler {
pub async fn free(&self, request_id: &str) {
let _ = self.slots.free(&request_id.to_string()).await;
}
pub async fn get_potential_loads(
&self,
token_seq: Vec<SequenceHash>,
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
.await;
// Get all unique worker IDs from both hashmaps
let mut worker_ids: HashSet<i64> = HashSet::new();
worker_ids.extend(decode_blocks.keys().copied());
worker_ids.extend(prefill_tokens.keys().copied());
// Create PotentialLoad for each worker
let mut loads = Vec::new();
for worker_id in worker_ids {
loads.push(PotentialLoad {
worker_id,
potential_prefill_tokens: prefill_tokens
.get(&worker_id)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker_id).copied().unwrap_or(0),
});
}
loads
}
}
// Helper function for softmax sampling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment