Unverified commit 2cf427ed, authored by Waël Boukhobza, committed by GitHub
Browse files

fix(router): change prune channel to be mpsc instead of watch, increase max size (#4351)


Signed-off-by: Wael Boukhobza <wawa_wael@live.fr>
parent 84737baf
......@@ -727,7 +727,7 @@ impl ApproxKvIndexer {
fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> {
let ttl = tokio::time::Duration::from_secs_f64(ttl_secs);
let prune_config = Some(llm_rs::kv_router::approx::PruneConfig {
max_tree_size: 2usize.pow(14), // 2** 14 = 16384
max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
prune_target_ratio: 0.8,
});
let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new(
......
......@@ -265,7 +265,7 @@ impl KvRouter {
block_size,
Duration::from_secs(120),
Some(PruneConfig {
max_tree_size: 2usize.pow(14), // 2** 14 = 16384
max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
prune_target_ratio: 0.8,
}),
))
......
......@@ -21,7 +21,7 @@ use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use std::sync::OnceLock;
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -277,7 +277,7 @@ impl ApproxKvIndexer {
let (_get_workers_tx, mut get_workers_rx) =
mpsc::channel::<super::indexer::GetWorkersRequest>(16);
let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
let (prune_tx, mut prune_rx) = watch::channel(false);
let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1);
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread
......@@ -302,6 +302,8 @@ impl ApproxKvIndexer {
};
tokio::select! {
biased;
_ = cancel_clone.cancelled() => {
tracing::debug!("Approximate Indexer progress loop shutting down");
return;
......@@ -316,6 +318,27 @@ impl ApproxKvIndexer {
let _ = get_workers_req.resp.send(workers);
}
Some(_) = prune_rx.recv() => {
// The tree has exceeded the max tree size, so proceed with pruning.
if let Ok(pruned) = prune_manager.prune(trie.current_size()) {
pruned.iter().for_each(|p| {
event_id += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
});
}
}
Some(result) = route_rx.recv() => {
let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter());
......@@ -353,9 +376,9 @@ impl ApproxKvIndexer {
current_size,
prune_config.max_tree_size
);
// Send a signal to the pruning watcher to schedule pruning.
if let Err(e) = prune_tx.send(true) {
tracing::error!("Failed to send prune schedule signal: {:?}", e);
// Send a signal to the pruning receiver to schedule pruning.
if let Err(mpsc::error::TrySendError::Closed(_)) = prune_tx.try_send(()) {
tracing::error!("Failed to send prune schedule signal, prune receiver is closed");
}
}
}
......@@ -372,31 +395,6 @@ impl ApproxKvIndexer {
request.resp.send(scores).unwrap();
}
Ok(_) = prune_rx.changed() => {
// The tree has exceeded the max tree size, so proceed with pruning.
if let Ok(pruned) = prune_manager.prune(trie.current_size()) {
pruned.iter().for_each(|p| {
event_id += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
});
// Reset the pruning watcher to false to indicate that pruning is complete.
if let Err(e) = prune_tx.send(true) {
tracing::error!("Failed to send prune completion signal: {:?}", e);
}
}
}
_ = expiry_fut => {
let expired = prune_manager.pop_expired();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment