Unverified Commit b5e762b2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: remove stale workers on snapshot + some refactoring (#3589)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent f0065bb4
...@@ -416,13 +416,14 @@ impl KvIndexer { ...@@ -416,13 +416,14 @@ impl KvIndexer {
.into(); .into();
// Use the shared start_kv_router_background function for event consumption // Use the shared start_kv_router_background function for event consumption
// Pass None for snapshot_tx to skip snapshot handling in Python bindings // Pass None for snapshot_tx and get_workers_tx to skip snapshot handling in Python bindings
llm_rs::kv_router::subscriber::start_kv_router_background( llm_rs::kv_router::subscriber::start_kv_router_background(
component.inner.clone(), component.inner.clone(),
consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
inner.event_sender(), inner.event_sender(),
inner.remove_worker_sender(), inner.remove_worker_sender(),
None, None,
None,
cancellation_token, cancellation_token,
None, None,
true, true,
......
...@@ -292,6 +292,9 @@ impl KvRouter { ...@@ -292,6 +292,9 @@ impl KvRouter {
consumer_uuid, consumer_uuid,
kv_indexer.event_sender(), kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(), kv_indexer.remove_worker_sender(),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.get_workers_sender()),
kv_router_config kv_router_config
.router_snapshot_threshold .router_snapshot_threshold
.map(|_| kv_indexer.snapshot_event_sender()), .map(|_| kv_indexer.snapshot_event_sender()),
......
...@@ -184,6 +184,8 @@ impl ApproxKvIndexer { ...@@ -184,6 +184,8 @@ impl ApproxKvIndexer {
let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048); let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048);
let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048); let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048);
let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (_get_workers_tx, mut get_workers_rx) =
mpsc::channel::<super::indexer::GetWorkersRequest>(16);
let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16); let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
let cancel_clone = token.clone(); let cancel_clone = token.clone();
let task = std::thread::spawn(move || { let task = std::thread::spawn(move || {
...@@ -217,6 +219,11 @@ impl ApproxKvIndexer { ...@@ -217,6 +219,11 @@ impl ApproxKvIndexer {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some(get_workers_req) = get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(result) = route_rx.recv() => { Some(result) = route_rx.recv() => {
let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter()); let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter());
......
...@@ -469,6 +469,11 @@ impl RadixTree { ...@@ -469,6 +469,11 @@ impl RadixTree {
} }
} }
/// Get all worker IDs currently tracked in the radix tree.
pub fn get_workers(&self) -> Vec<WorkerId> {
self.lookup.keys().copied().collect()
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree. /// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique, /// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost. /// though the exact event ordering will be lost.
...@@ -704,6 +709,12 @@ pub struct DumpRequest { ...@@ -704,6 +709,12 @@ pub struct DumpRequest {
pub resp: oneshot::Sender<Vec<RouterEvent>>, pub resp: oneshot::Sender<Vec<RouterEvent>>,
} }
/// A request to get all workers currently tracked
pub struct GetWorkersRequest {
/// Channel to send the worker IDs
pub resp: oneshot::Sender<Vec<WorkerId>>,
}
#[async_trait] #[async_trait]
pub trait KvIndexerInterface { pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es. /// Find matches for a given sequence of `LocalBlockHash`es.
...@@ -769,6 +780,8 @@ pub struct KvIndexer { ...@@ -769,6 +780,8 @@ pub struct KvIndexer {
match_tx: mpsc::Sender<MatchRequest>, match_tx: mpsc::Sender<MatchRequest>,
/// A sender for remove worker requests. /// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for get workers requests.
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
/// A sender for dump requests. /// A sender for dump requests.
dump_tx: mpsc::Sender<DumpRequest>, dump_tx: mpsc::Sender<DumpRequest>,
/// A handle to the background task managing the KV store. /// A handle to the background task managing the KV store.
...@@ -797,6 +810,7 @@ impl KvIndexer { ...@@ -797,6 +810,7 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (get_workers_tx, get_workers_rx) = mpsc::channel::<GetWorkersRequest>(16);
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16); let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let cancel_clone = token.clone(); let cancel_clone = token.clone();
...@@ -812,6 +826,7 @@ impl KvIndexer { ...@@ -812,6 +826,7 @@ impl KvIndexer {
let mut match_rx = match_rx; let mut match_rx = match_rx;
let mut event_rx = event_rx; let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx; let mut remove_worker_rx = remove_worker_rx;
let mut get_workers_rx = get_workers_rx;
let mut dump_rx = dump_rx; let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration); let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop { loop {
...@@ -827,6 +842,11 @@ impl KvIndexer { ...@@ -827,6 +842,11 @@ impl KvIndexer {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some(get_workers_req) = get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(event) = event_rx.recv() => { Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event); let result = trie.apply_event(event);
...@@ -857,6 +877,7 @@ impl KvIndexer { ...@@ -857,6 +877,7 @@ impl KvIndexer {
event_tx, event_tx,
match_tx, match_tx,
remove_worker_tx, remove_worker_tx,
get_workers_tx,
dump_tx, dump_tx,
task: once, task: once,
kv_block_size, kv_block_size,
...@@ -901,6 +922,15 @@ impl KvIndexer { ...@@ -901,6 +922,15 @@ impl KvIndexer {
pub fn remove_worker_sender(&self) -> mpsc::Sender<WorkerId> { pub fn remove_worker_sender(&self) -> mpsc::Sender<WorkerId> {
self.remove_worker_tx.clone() self.remove_worker_tx.clone()
} }
/// Get a sender for get workers requests.
///
/// ### Returns
///
/// A `mpsc::Sender` for `GetWorkersRequest`s.
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.get_workers_tx.clone()
}
} }
#[async_trait] #[async_trait]
...@@ -1039,6 +1069,7 @@ impl KvIndexerSharded { ...@@ -1039,6 +1069,7 @@ impl KvIndexerSharded {
let mut event_tx = Vec::new(); let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new(); let mut remove_worker_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new(); // Add dump channels let mut dump_tx = Vec::new(); // Add dump channels
let mut tasks = Vec::new(); let mut tasks = Vec::new();
...@@ -1048,6 +1079,8 @@ impl KvIndexerSharded { ...@@ -1048,6 +1079,8 @@ impl KvIndexerSharded {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048); let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
let (shard_remove_worker_tx, mut shard_remove_worker_rx) = let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
mpsc::channel::<WorkerId>(16); mpsc::channel::<WorkerId>(16);
let (shard_get_workers_tx, mut shard_get_workers_rx) =
mpsc::channel::<GetWorkersRequest>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); // Add dump channel let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); // Add dump channel
let mut shard_broadcast_rx = request_broadcast_tx.subscribe(); let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
let cancel = token.clone(); let cancel = token.clone();
...@@ -1055,6 +1088,7 @@ impl KvIndexerSharded { ...@@ -1055,6 +1088,7 @@ impl KvIndexerSharded {
event_tx.push(shard_event_tx); event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx); remove_worker_tx.push(shard_remove_worker_tx);
get_workers_tx.push(shard_get_workers_tx);
dump_tx.push(shard_dump_tx); // Store dump sender dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_current_thread() let runtime = tokio::runtime::Builder::new_current_thread()
...@@ -1078,6 +1112,11 @@ impl KvIndexerSharded { ...@@ -1078,6 +1112,11 @@ impl KvIndexerSharded {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some(get_workers_req) = shard_get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(event) = shard_event_rx.recv() => { Some(event) = shard_event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event); let result = trie.apply_event(event);
......
...@@ -23,7 +23,7 @@ use crate::{ ...@@ -23,7 +23,7 @@ use crate::{
kv_router::{ kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK, ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, RouterEvent}, indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerId},
}, },
}; };
...@@ -32,17 +32,18 @@ use crate::{ ...@@ -32,17 +32,18 @@ use crate::{
struct SnapshotResources { struct SnapshotResources {
nats_client: dynamo_runtime::transports::nats::Client, nats_client: dynamo_runtime::transports::nats::Client,
bucket_name: String, bucket_name: String,
etcd_client: EtcdClient,
lock_name: String, lock_name: String,
instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>,
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
snapshot_tx: mpsc::Sender<DumpRequest>,
} }
impl SnapshotResources { impl SnapshotResources {
/// Try to acquire distributed lock for snapshot operations /// Try to acquire distributed lock for snapshot operations
/// Returns Some(lock_response) if lock acquired, None if another instance holds it /// Returns Some(lock_response) if lock acquired, None if another instance holds it
async fn lock(&self) -> Option<etcd_client::LockResponse> { async fn lock(&self, etcd_client: &EtcdClient) -> Option<etcd_client::LockResponse> {
match self match etcd_client
.etcd_client .lock(self.lock_name.clone(), Some(etcd_client.lease_id()))
.lock(self.lock_name.clone(), Some(self.etcd_client.lease_id()))
.await .await
{ {
Ok(response) => { Ok(response) => {
...@@ -60,11 +61,101 @@ impl SnapshotResources { ...@@ -60,11 +61,101 @@ impl SnapshotResources {
} }
/// Release the distributed lock /// Release the distributed lock
async fn unlock(&self, lock_response: etcd_client::LockResponse) { async fn unlock(&self, etcd_client: &EtcdClient, lock_response: etcd_client::LockResponse) {
if let Err(e) = self.etcd_client.unlock(lock_response.key()).await { if let Err(e) = etcd_client.unlock(lock_response.key()).await {
tracing::warn!("Failed to release snapshot lock: {e:?}"); tracing::warn!("Failed to release snapshot lock: {e:?}");
} }
} }
/// Perform snapshot upload and purge operations
async fn purge_then_snapshot(
&self,
nats_queue: &mut NatsQueue,
remove_worker_tx: &mpsc::Sender<WorkerId>,
) -> anyhow::Result<()> {
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();
// Clean up stale workers before snapshot
// Get current worker IDs from instances_rx
let current_instances = self.instances_rx.borrow().clone();
let current_worker_ids: std::collections::HashSet<i64> = current_instances
.iter()
.map(|instance| instance.instance_id)
.collect();
// Get worker IDs from the indexer
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let get_workers_req = GetWorkersRequest { resp: resp_tx };
if let Err(e) = self.get_workers_tx.send(get_workers_req).await {
tracing::warn!("Failed to send get_workers request during snapshot: {e:?}");
} else {
match resp_rx.await {
Ok(indexer_worker_ids) => {
// Find workers in indexer but not in current instances
for worker_id in indexer_worker_ids {
if !current_worker_ids.contains(&worker_id) {
tracing::info!(
"Removing stale worker {} from indexer during snapshot",
worker_id
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!(
"Failed to send remove_worker for stale worker {}: {e:?}",
worker_id
);
}
}
}
}
Err(e) => {
tracing::warn!("Failed to receive worker IDs from indexer: {e:?}");
}
}
}
// First, purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
// Now request a snapshot from the indexer (which reflects the post-purge state)
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
self.snapshot_tx
.send(dump_req)
.await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
// Upload the snapshot to NATS object store
let url = url::Url::parse(&format!(
"nats://{}/{}/{RADIX_STATE_FILE}",
self.nats_client.addr(),
self.bucket_name
))?;
self.nats_client
.object_store_upload_data(&events, &url)
.await
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
tracing::info!(
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
events.len(),
self.bucket_name,
start_time.elapsed().as_millis()
);
Ok(())
}
} }
/// Start a unified background task for event consumption and optional snapshot management /// Start a unified background task for event consumption and optional snapshot management
...@@ -73,8 +164,9 @@ pub async fn start_kv_router_background( ...@@ -73,8 +164,9 @@ pub async fn start_kv_router_background(
component: Component, component: Component,
consumer_uuid: String, consumer_uuid: String,
kv_events_tx: mpsc::Sender<RouterEvent>, kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<crate::kv_router::indexer::WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
snapshot_tx: Option<mpsc::Sender<DumpRequest>>, maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
maybe_snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
router_snapshot_threshold: Option<u32>, router_snapshot_threshold: Option<u32>,
router_reset_states: bool, router_reset_states: bool,
...@@ -168,15 +260,30 @@ pub async fn start_kv_router_background( ...@@ -168,15 +260,30 @@ pub async fn start_kv_router_background(
.await? .await?
.dissolve(); .dissolve();
// Only set up snapshot-related resources if snapshot_tx is provided and threshold is set // Get instances_rx for tracking current workers
let snapshot_resources = if snapshot_tx.is_some() && router_snapshot_threshold.is_some() { let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() {
dynamo_runtime::component::InstanceSource::Dynamic(rx) => rx.clone(),
dynamo_runtime::component::InstanceSource::Static => {
anyhow::bail!("Expected dynamic instance source for KV routing");
}
};
// Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided
let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = (
maybe_get_workers_tx,
maybe_snapshot_tx,
router_snapshot_threshold,
) {
let lock_name = format!("{}/{}", ROUTER_SNAPSHOT_LOCK, component.subject()); let lock_name = format!("{}/{}", ROUTER_SNAPSHOT_LOCK, component.subject());
Some(SnapshotResources { Some(SnapshotResources {
nats_client, nats_client,
bucket_name, bucket_name,
etcd_client: etcd_client.clone(),
lock_name, lock_name,
instances_rx,
get_workers_tx,
snapshot_tx,
}) })
} else { } else {
None None
...@@ -256,9 +363,9 @@ pub async fn start_kv_router_background( ...@@ -256,9 +363,9 @@ pub async fn start_kv_router_background(
} }
} }
// Handle periodic stream checking and purging (only if snapshot_tx is provided) // Handle periodic stream checking and purging (only if snapshot_resources is provided)
_ = check_interval.tick() => { _ = check_interval.tick() => {
let Some((snapshot_tx, resources)) = snapshot_tx.as_ref().zip(snapshot_resources.as_ref()) else { let Some(resources) = snapshot_resources.as_ref() else {
continue; continue;
}; };
...@@ -277,22 +384,21 @@ pub async fn start_kv_router_background( ...@@ -277,22 +384,21 @@ pub async fn start_kv_router_background(
tracing::info!("Stream has {message_count} messages, attempting to acquire lock for purge and snapshot"); tracing::info!("Stream has {message_count} messages, attempting to acquire lock for purge and snapshot");
// Try to acquire distributed lock // Try to acquire distributed lock
let Some(lock_response) = resources.lock().await else { let Some(lock_response) = resources.lock(&etcd_client).await else {
continue; continue;
}; };
// Perform snapshot upload and purge // Perform snapshot upload and purge
match purge_then_snapshot( match resources.purge_then_snapshot(
&mut nats_queue, &mut nats_queue,
snapshot_tx, &remove_worker_tx,
resources
).await { ).await {
Ok(_) => tracing::info!("Successfully performed purge and snapshot"), Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
Err(e) => tracing::error!("Failed to perform purge and snapshot: {e:?}"), Err(e) => tracing::error!("Failed to perform purge and snapshot: {e:?}"),
} }
// Release the lock // Release the lock
resources.unlock(lock_response).await; resources.unlock(&etcd_client, lock_response).await;
} }
// Handle router deletion events // Handle router deletion events
...@@ -405,55 +511,3 @@ async fn cleanup_orphaned_consumers( ...@@ -405,55 +511,3 @@ async fn cleanup_orphaned_consumers(
} }
} }
} }
/// Perform snapshot upload and purge operations
async fn purge_then_snapshot(
nats_queue: &mut NatsQueue,
snapshot_tx: &mpsc::Sender<DumpRequest>,
resources: &SnapshotResources,
) -> anyhow::Result<()> {
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();
// First, purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
// Now request a snapshot from the indexer (which reflects the post-purge state)
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
snapshot_tx
.send(dump_req)
.await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
// Upload the snapshot to NATS object store
let url = url::Url::parse(&format!(
"nats://{}/{}/{RADIX_STATE_FILE}",
resources.nats_client.addr(),
resources.bucket_name
))?;
resources
.nats_client
.object_store_upload_data(&events, &url)
.await
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
tracing::info!(
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
events.len(),
resources.bucket_name,
start_time.elapsed().as_millis()
);
Ok(())
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment