Unverified commit 2cf427ed, authored by Waël Boukhobza and committed by GitHub
Browse files

fix(router): change prune channel to be mpsc instead of watch, increase max size (#4351)


Signed-off-by: Wael Boukhobza <wawa_wael@live.fr>
parent 84737baf
...@@ -727,7 +727,7 @@ impl ApproxKvIndexer { ...@@ -727,7 +727,7 @@ impl ApproxKvIndexer {
fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> { fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> {
let ttl = tokio::time::Duration::from_secs_f64(ttl_secs); let ttl = tokio::time::Duration::from_secs_f64(ttl_secs);
let prune_config = Some(llm_rs::kv_router::approx::PruneConfig { let prune_config = Some(llm_rs::kv_router::approx::PruneConfig {
max_tree_size: 2usize.pow(14), // 2** 14 = 16384 max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
prune_target_ratio: 0.8, prune_target_ratio: 0.8,
}); });
let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new( let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new(
......
...@@ -265,7 +265,7 @@ impl KvRouter { ...@@ -265,7 +265,7 @@ impl KvRouter {
block_size, block_size,
Duration::from_secs(120), Duration::from_secs(120),
Some(PruneConfig { Some(PruneConfig {
max_tree_size: 2usize.pow(14), // 2** 14 = 16384 max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
prune_target_ratio: 0.8, prune_target_ratio: 0.8,
}), }),
)) ))
......
...@@ -21,7 +21,7 @@ use std::cmp::Reverse; ...@@ -21,7 +21,7 @@ use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap}; use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash; use std::hash::Hash;
use std::sync::OnceLock; use std::sync::OnceLock;
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -277,7 +277,7 @@ impl ApproxKvIndexer { ...@@ -277,7 +277,7 @@ impl ApproxKvIndexer {
let (_get_workers_tx, mut get_workers_rx) = let (_get_workers_tx, mut get_workers_rx) =
mpsc::channel::<super::indexer::GetWorkersRequest>(16); mpsc::channel::<super::indexer::GetWorkersRequest>(16);
let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16); let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
let (prune_tx, mut prune_rx) = watch::channel(false); let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1);
let cancel_clone = token.clone(); let cancel_clone = token.clone();
let task = std::thread::spawn(move || { let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread // create a new tokio runtime which will only perform work on a single thread
...@@ -302,6 +302,8 @@ impl ApproxKvIndexer { ...@@ -302,6 +302,8 @@ impl ApproxKvIndexer {
}; };
tokio::select! { tokio::select! {
biased;
_ = cancel_clone.cancelled() => { _ = cancel_clone.cancelled() => {
tracing::debug!("Approximate Indexer progress loop shutting down"); tracing::debug!("Approximate Indexer progress loop shutting down");
return; return;
...@@ -316,6 +318,27 @@ impl ApproxKvIndexer { ...@@ -316,6 +318,27 @@ impl ApproxKvIndexer {
let _ = get_workers_req.resp.send(workers); let _ = get_workers_req.resp.send(workers);
} }
Some(_) = prune_rx.recv() => {
// The tree has exceeded the max tree size, so proceed with pruning.
if let Ok(pruned) = prune_manager.prune(trie.current_size()) {
pruned.iter().for_each(|p| {
event_id += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
});
}
}
Some(result) = route_rx.recv() => { Some(result) = route_rx.recv() => {
let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter()); let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter());
...@@ -353,9 +376,9 @@ impl ApproxKvIndexer { ...@@ -353,9 +376,9 @@ impl ApproxKvIndexer {
current_size, current_size,
prune_config.max_tree_size prune_config.max_tree_size
); );
// Send a signal to the pruning watcher to schedule pruning. // Send a signal to the pruning receiver to schedule pruning.
if let Err(e) = prune_tx.send(true) { if let Err(mpsc::error::TrySendError::Closed(_)) = prune_tx.try_send(()) {
tracing::error!("Failed to send prune schedule signal: {:?}", e); tracing::error!("Failed to send prune schedule signal, prune receiver is closed");
} }
} }
} }
...@@ -372,31 +395,6 @@ impl ApproxKvIndexer { ...@@ -372,31 +395,6 @@ impl ApproxKvIndexer {
request.resp.send(scores).unwrap(); request.resp.send(scores).unwrap();
} }
Ok(_) = prune_rx.changed() => {
// The tree has exceeded the max tree size, so proceed with pruning.
if let Ok(pruned) = prune_manager.prune(trie.current_size()) {
pruned.iter().for_each(|p| {
event_id += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
});
// Reset the pruning watcher to false to indicate that pruning is complete.
if let Err(e) = prune_tx.send(true) {
tracing::error!("Failed to send prune completion signal: {:?}", e);
}
}
}
_ = expiry_fut => { _ = expiry_fut => {
let expired = prune_manager.pop_expired(); let expired = prune_manager.pop_expired();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment