"lib/ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "dd6c399565fe203898e14f1d92c87be35f07f24f"
Unverified Commit 13640e15 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: allow using ApproxKvIndexer for routing via use_kv_events flag (#1869)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
parent 61a1f4ff
...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm. ...@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage: Usage:
``` ```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)] dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--use-kv-events=true] [--verbosity (-v|-vv)]
``` ```
Example: `dynamo run Qwen/Qwen3-0.6B` Example: `dynamo run Qwen/Qwen3-0.6B`
...@@ -201,7 +201,13 @@ The only difference from the distributed system above is `--router-mode kv`. The ...@@ -201,7 +201,13 @@ The only difference from the distributed system above is `--router-mode kv`. The
For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing. For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.
The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked. The KV-aware routing arguments:
- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks).
- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked.
- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
## Full usage details ## Full usage details
......
...@@ -128,6 +128,13 @@ pub struct Flags { ...@@ -128,6 +128,13 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub router_temperature: Option<f64>, pub router_temperature: Option<f64>,
/// KV Router: Whether to use KV events to maintain the view of cached blocks
/// If false, would use ApproxKvRouter for predicting block creation / deletion
/// based only on incoming requests at a timer.
/// Default: true
#[arg(long)]
pub use_kv_events: Option<bool>,
/// Max model context length. Reduce this if you don't have enough VRAM for the full model /// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4). /// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
...@@ -215,6 +222,7 @@ impl Flags { ...@@ -215,6 +222,7 @@ impl Flags {
KvRouterConfig::new( KvRouterConfig::new(
self.kv_overlap_score_weight, self.kv_overlap_score_weight,
self.router_temperature, self.router_temperature,
self.use_kv_events,
self.max_num_batched_tokens, self.max_num_batched_tokens,
), ),
) )
......
...@@ -212,18 +212,12 @@ impl ModelManager { ...@@ -212,18 +212,12 @@ impl ModelManager {
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
// Determine if we should use KV events based on overlap score weight let selector = Box::new(DefaultWorkerSelector::new(kv_router_config.clone()));
let use_kv_events = kv_router_config
.as_ref()
.map(|config| config.overlap_score_weight > 0.0)
.unwrap_or(false);
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new( let chooser = KvRouter::new(
component.clone(), component.clone(),
kv_cache_block_size, kv_cache_block_size,
Some(selector), Some(selector),
use_kv_events, kv_router_config.unwrap_or_default().use_kv_events,
) )
.await?; .await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -14,6 +15,7 @@ use dynamo_runtime::{ ...@@ -14,6 +15,7 @@ use dynamo_runtime::{
protocols::annotated::Annotated, protocols::annotated::Annotated,
}; };
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use tokio::sync::Mutex;
pub mod approx; pub mod approx;
pub mod indexer; pub mod indexer;
...@@ -27,7 +29,11 @@ pub mod sequence; ...@@ -27,7 +29,11 @@ pub mod sequence;
use crate::{ use crate::{
kv_router::{ kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, approx::ApproxKvIndexer,
indexer::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
OverlapScores, RouterEvent,
},
metrics_aggregator::EndpointCollector, metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
...@@ -35,7 +41,6 @@ use crate::{ ...@@ -35,7 +41,6 @@ use crate::{
}, },
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
tokens::TokenBlockSequence,
}; };
use dynamo_runtime::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
...@@ -63,6 +68,8 @@ pub struct KvRouterConfig { ...@@ -63,6 +68,8 @@ pub struct KvRouterConfig {
pub router_temperature: f64, pub router_temperature: f64,
pub use_kv_events: bool,
// note: this is not actually used for now // note: this is not actually used for now
pub max_num_batched_tokens: u32, pub max_num_batched_tokens: u32,
} }
...@@ -72,6 +79,7 @@ impl Default for KvRouterConfig { ...@@ -72,6 +79,7 @@ impl Default for KvRouterConfig {
Self { Self {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
router_temperature: 0.5, router_temperature: 0.5,
use_kv_events: true,
max_num_batched_tokens: 8192, max_num_batched_tokens: 8192,
} }
} }
...@@ -83,24 +91,52 @@ impl KvRouterConfig { ...@@ -83,24 +91,52 @@ impl KvRouterConfig {
pub fn new( pub fn new(
overlap_score_weight: Option<f64>, overlap_score_weight: Option<f64>,
temperature: Option<f64>, temperature: Option<f64>,
use_kv_events: Option<bool>,
max_num_batched_tokens: Option<u32>, max_num_batched_tokens: Option<u32>,
) -> Self { ) -> Self {
let default = Self::default(); let default = Self::default();
Self { Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
router_temperature: temperature.unwrap_or(default.router_temperature), router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
max_num_batched_tokens: max_num_batched_tokens max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens), .unwrap_or(default.max_num_batched_tokens),
} }
} }
} }
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
KvIndexer(KvIndexer),
ApproxKvIndexer(ApproxKvIndexer),
}
impl Indexer {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
}
}
}
/// A KvRouter only decides which worker you should use. It doesn't send you there. /// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route. /// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter { pub struct KvRouter {
indexer: Option<KvIndexer>, indexer: Indexer,
// How about a Box<dyn KvIndexerInterface>
scheduler: KvScheduler, scheduler: KvScheduler,
block_size: u32, block_size: u32,
// To ensure blocking reads / writes
// TODO: benchmark tradeoffs
find_best_match_mutex: Mutex<()>,
} }
impl KvRouter { impl KvRouter {
...@@ -118,8 +154,16 @@ impl KvRouter { ...@@ -118,8 +154,16 @@ impl KvRouter {
let metrics_aggregator = let metrics_aggregator =
EndpointCollector::new(component.clone(), cancellation_token.clone()).await; EndpointCollector::new(component.clone(), cancellation_token.clone()).await;
let maybe_indexer = let indexer = if use_kv_events {
use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size)); Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
} else {
// hard code 120 seconds for now
Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
cancellation_token.clone(),
block_size,
Duration::from_secs(120),
))
};
let scheduler = KvScheduler::start( let scheduler = KvScheduler::start(
component.namespace().clone(), component.namespace().clone(),
...@@ -131,9 +175,9 @@ impl KvRouter { ...@@ -131,9 +175,9 @@ impl KvRouter {
// [gluo TODO] try subscribe_with_type::<RouterEvent>, // [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different. // error checking below will be different.
if let Some(ref indexer) = maybe_indexer { if let Indexer::KvIndexer(ref kv_indexer) = indexer {
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?; let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender(); let kv_events_tx = kv_indexer.event_sender();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await { while let Some(event) = kv_events_rx.next().await {
...@@ -158,9 +202,10 @@ impl KvRouter { ...@@ -158,9 +202,10 @@ impl KvRouter {
tracing::info!("KV Routing initialized"); tracing::info!("KV Routing initialized");
Ok(Self { Ok(Self {
indexer: maybe_indexer, indexer,
scheduler, scheduler,
block_size, block_size,
find_best_match_mutex: Mutex::new(()), // Add this
}) })
} }
...@@ -172,20 +217,15 @@ impl KvRouter { ...@@ -172,20 +217,15 @@ impl KvRouter {
context_id: &str, context_id: &str,
tokens: &[u32], tokens: &[u32],
) -> anyhow::Result<(i64, u32)> { ) -> anyhow::Result<(i64, u32)> {
// Acquire mutex to serialize access
// TODO: may as well make all the subroutines synchronous if benchmarking favors this
let _guard = self.find_best_match_mutex.lock().await;
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_size = self.block_size; let block_size = self.block_size;
let (complete_blocks, _partial_block) = let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64); let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let local_block_hashes = complete_blocks
.into_iter()
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = match &self.indexer {
Some(indexer) => indexer.find_matches(local_block_hashes).await?,
None => Default::default(), // Returns empty/default instance
};
let best_worker_id = self let best_worker_id = self
.scheduler .scheduler
...@@ -198,6 +238,13 @@ impl KvRouter { ...@@ -198,6 +238,13 @@ impl KvRouter {
) )
.await?; .await?;
if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision_for_request(tokens, best_worker_id)
.await
.unwrap();
};
let overlap_amount = overlap_scores let overlap_amount = overlap_scores
.scores .scores
.get(&best_worker_id) .get(&best_worker_id)
......
...@@ -72,6 +72,7 @@ struct TimerEntry { ...@@ -72,6 +72,7 @@ struct TimerEntry {
struct TimerManager<K: Clone + Hash + Eq + Ord> { struct TimerManager<K: Clone + Hash + Eq + Ord> {
/// The source of truth. Maps a key to its current expiration instant. /// The source of truth. Maps a key to its current expiration instant.
timers: HashMap<K, Instant>, timers: HashMap<K, Instant>,
/// A min-heap of (expiration_instant, key) used to efficiently find the /// A min-heap of (expiration_instant, key) used to efficiently find the
/// next expiring timer. An entry in this heap is "stale" if the instant /// next expiring timer. An entry in this heap is "stale" if the instant
/// does not match the one in the `timers` map. /// does not match the one in the `timers` map.
...@@ -79,18 +80,32 @@ struct TimerManager<K: Clone + Hash + Eq + Ord> { ...@@ -79,18 +80,32 @@ struct TimerManager<K: Clone + Hash + Eq + Ord> {
/// The expiration duration of the timers. /// The expiration duration of the timers.
ttl: Duration, ttl: Duration,
/// Threshold for rebuilding the heap.
/// The heap will be rebuilt from scratch to remove stale entries.
threshold: usize,
} }
impl<K: Clone + Hash + Eq + Ord> TimerManager<K> { impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
/// Creates a new, empty TimerManager. /// Creates a new, empty TimerManager.
pub fn new(ttl: Duration) -> Self { pub fn new(ttl: Duration, threshold: usize) -> Self {
TimerManager { TimerManager {
timers: HashMap::new(), timers: HashMap::new(),
expirations: BinaryHeap::new(), expirations: BinaryHeap::new(),
ttl, ttl,
threshold,
} }
} }
/// Rebuilds the expirations heap from the timers map, removing all stale entries.
fn rebuild_heap(&mut self) {
self.expirations = self
.timers
.iter()
.map(|(key, &expiry)| Reverse((expiry, key.clone())))
.collect();
}
/// Inserts a new timer or updates an existing one for the given key. /// Inserts a new timer or updates an existing one for the given key.
/// ///
/// # Arguments /// # Arguments
...@@ -108,6 +123,11 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> { ...@@ -108,6 +123,11 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
// which will be ignored when it's popped. // which will be ignored when it's popped.
self.expirations.push(Reverse((expiry_time, key))); self.expirations.push(Reverse((expiry_time, key)));
} }
// Check if we should rebuild the heap to remove stale entries
if self.expirations.len() > self.timers.len() * self.threshold {
self.rebuild_heap();
}
} }
/// Polls for expired timers and returns a list of keys for all timers /// Polls for expired timers and returns a list of keys for all timers
...@@ -123,23 +143,12 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> { ...@@ -123,23 +143,12 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
} }
// The timer might be expired, so pop it from the heap. // The timer might be expired, so pop it from the heap.
// We can safely unwrap because we just peeked.
let Reverse((expiry_time, key)) = self.expirations.pop().unwrap(); let Reverse((expiry_time, key)) = self.expirations.pop().unwrap();
// CRUCIAL STEP: Check if the popped timer is stale. if self.timers.get(&key) == Some(&expiry_time) {
// A timer is stale if its key is no longer in our authoritative map, // This is a valid, non-stale, expired timer.
// or if the expiration time in the map is different (i.e., it was updated). self.timers.remove(&key);
match self.timers.get(&key) { expired_keys.push(key);
Some(authoritative_expiry) if *authoritative_expiry == expiry_time => {
// This is a valid, non-stale, expired timer.
// Remove it from the map and add its key to our results.
self.timers.remove(&key);
expired_keys.push(key);
}
_ => {
// This entry in the heap was stale. It was either removed
// or updated with a new time. We just ignore it and continue.
}
} }
} }
...@@ -184,7 +193,8 @@ impl ApproxKvIndexer { ...@@ -184,7 +193,8 @@ impl ApproxKvIndexer {
runtime.block_on(async move { runtime.block_on(async move {
let mut trie = RadixTree::new(); let mut trie = RadixTree::new();
let mut timer_manager: TimerManager<TimerEntry> = TimerManager::new(ttl); // Use a reasonable threshold - can be made configurable if needed
let mut timer_manager: TimerManager<TimerEntry> = TimerManager::new(ttl, 50);
let mut event_id = 0; let mut event_id = 0;
loop { loop {
// Create a future that sleeps until the next expiration time. // Create a future that sleeps until the next expiration time.
...@@ -398,7 +408,7 @@ mod tests { ...@@ -398,7 +408,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_timer_manager_expiry() { async fn test_timer_manager_expiry() {
const TTL: Duration = Duration::from_millis(50); const TTL: Duration = Duration::from_millis(50);
let mut tm: TimerManager<u32> = TimerManager::new(TTL); let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);
tm.insert(vec![1, 2, 3]); tm.insert(vec![1, 2, 3]);
assert!(tm.get_expiry(&1).is_some()); assert!(tm.get_expiry(&1).is_some());
...@@ -419,7 +429,7 @@ mod tests { ...@@ -419,7 +429,7 @@ mod tests {
async fn test_timer_manager_update_resets_ttl() { async fn test_timer_manager_update_resets_ttl() {
// Validate that reinserting an existing key extends its TTL and prevents premature expiry. // Validate that reinserting an existing key extends its TTL and prevents premature expiry.
const TTL: Duration = Duration::from_millis(50); const TTL: Duration = Duration::from_millis(50);
let mut tm: TimerManager<u32> = TimerManager::new(TTL); let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);
// Initial insert and capture the original expiry. // Initial insert and capture the original expiry.
tm.insert(vec![42]); tm.insert(vec![42]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment