Unverified Commit b5e762b2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: remove stale workers on snapshot + some refactoring (#3589)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent f0065bb4
......@@ -416,13 +416,14 @@ impl KvIndexer {
.into();
// Use the shared start_kv_router_background function for event consumption
// Pass None for snapshot_tx to skip snapshot handling in Python bindings
// Pass None for snapshot_tx and get_workers_tx to skip snapshot handling in Python bindings
llm_rs::kv_router::subscriber::start_kv_router_background(
component.inner.clone(),
consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
inner.event_sender(),
inner.remove_worker_sender(),
None,
None,
cancellation_token,
None,
true,
......
......@@ -292,6 +292,9 @@ impl KvRouter {
consumer_uuid,
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.get_workers_sender()),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.snapshot_event_sender()),
......
......@@ -184,6 +184,8 @@ impl ApproxKvIndexer {
let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048);
let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048);
let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (_get_workers_tx, mut get_workers_rx) =
mpsc::channel::<super::indexer::GetWorkersRequest>(16);
let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
......@@ -217,6 +219,11 @@ impl ApproxKvIndexer {
trie.remove_worker(worker);
}
Some(get_workers_req) = get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(result) = route_rx.recv() => {
let hashes = result.local_hashes.iter().zip(result.sequence_hashes.iter());
......
......@@ -469,6 +469,11 @@ impl RadixTree {
}
}
/// Get all worker IDs currently tracked in the radix tree.
pub fn get_workers(&self) -> Vec<WorkerId> {
self.lookup.keys().copied().collect()
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
......@@ -704,6 +709,12 @@ pub struct DumpRequest {
pub resp: oneshot::Sender<Vec<RouterEvent>>,
}
/// A request to get all workers currently tracked
pub struct GetWorkersRequest {
/// Channel to send the worker IDs
pub resp: oneshot::Sender<Vec<WorkerId>>,
}
#[async_trait]
pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es.
......@@ -769,6 +780,8 @@ pub struct KvIndexer {
match_tx: mpsc::Sender<MatchRequest>,
/// A sender for remove worker requests.
remove_worker_tx: mpsc::Sender<WorkerId>,
/// A sender for get workers requests.
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
/// A sender for dump requests.
dump_tx: mpsc::Sender<DumpRequest>,
/// A handle to the background task managing the KV store.
......@@ -797,6 +810,7 @@ impl KvIndexer {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (get_workers_tx, get_workers_rx) = mpsc::channel::<GetWorkersRequest>(16);
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let cancel_clone = token.clone();
......@@ -812,6 +826,7 @@ impl KvIndexer {
let mut match_rx = match_rx;
let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx;
let mut get_workers_rx = get_workers_rx;
let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop {
......@@ -827,6 +842,11 @@ impl KvIndexer {
trie.remove_worker(worker);
}
Some(get_workers_req) = get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
......@@ -857,6 +877,7 @@ impl KvIndexer {
event_tx,
match_tx,
remove_worker_tx,
get_workers_tx,
dump_tx,
task: once,
kv_block_size,
......@@ -901,6 +922,15 @@ impl KvIndexer {
pub fn remove_worker_sender(&self) -> mpsc::Sender<WorkerId> {
self.remove_worker_tx.clone()
}
/// Get a sender for get workers requests.
///
/// ### Returns
///
/// A `mpsc::Sender` for `GetWorkersRequest`s.
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.get_workers_tx.clone()
}
}
#[async_trait]
......@@ -1039,6 +1069,7 @@ impl KvIndexerSharded {
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new(); // Add dump channels
let mut tasks = Vec::new();
......@@ -1048,6 +1079,8 @@ impl KvIndexerSharded {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
mpsc::channel::<WorkerId>(16);
let (shard_get_workers_tx, mut shard_get_workers_rx) =
mpsc::channel::<GetWorkersRequest>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16); // Add dump channel
let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
let cancel = token.clone();
......@@ -1055,6 +1088,7 @@ impl KvIndexerSharded {
event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx);
get_workers_tx.push(shard_get_workers_tx);
dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_current_thread()
......@@ -1078,6 +1112,11 @@ impl KvIndexerSharded {
trie.remove_worker(worker);
}
Some(get_workers_req) = shard_get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(event) = shard_event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
......
......@@ -23,7 +23,7 @@ use crate::{
kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, RouterEvent},
indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerId},
},
};
......@@ -32,17 +32,18 @@ use crate::{
struct SnapshotResources {
nats_client: dynamo_runtime::transports::nats::Client,
bucket_name: String,
etcd_client: EtcdClient,
lock_name: String,
instances_rx: tokio::sync::watch::Receiver<Vec<dynamo_runtime::component::Instance>>,
get_workers_tx: mpsc::Sender<GetWorkersRequest>,
snapshot_tx: mpsc::Sender<DumpRequest>,
}
impl SnapshotResources {
/// Try to acquire distributed lock for snapshot operations
/// Returns Some(lock_response) if lock acquired, None if another instance holds it
async fn lock(&self) -> Option<etcd_client::LockResponse> {
match self
.etcd_client
.lock(self.lock_name.clone(), Some(self.etcd_client.lease_id()))
async fn lock(&self, etcd_client: &EtcdClient) -> Option<etcd_client::LockResponse> {
match etcd_client
.lock(self.lock_name.clone(), Some(etcd_client.lease_id()))
.await
{
Ok(response) => {
......@@ -60,11 +61,101 @@ impl SnapshotResources {
}
/// Release the distributed lock
async fn unlock(&self, lock_response: etcd_client::LockResponse) {
if let Err(e) = self.etcd_client.unlock(lock_response.key()).await {
async fn unlock(&self, etcd_client: &EtcdClient, lock_response: etcd_client::LockResponse) {
if let Err(e) = etcd_client.unlock(lock_response.key()).await {
tracing::warn!("Failed to release snapshot lock: {e:?}");
}
}
/// Perform snapshot upload and purge operations
async fn purge_then_snapshot(
&self,
nats_queue: &mut NatsQueue,
remove_worker_tx: &mpsc::Sender<WorkerId>,
) -> anyhow::Result<()> {
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();
// Clean up stale workers before snapshot
// Get current worker IDs from instances_rx
let current_instances = self.instances_rx.borrow().clone();
let current_worker_ids: std::collections::HashSet<i64> = current_instances
.iter()
.map(|instance| instance.instance_id)
.collect();
// Get worker IDs from the indexer
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let get_workers_req = GetWorkersRequest { resp: resp_tx };
if let Err(e) = self.get_workers_tx.send(get_workers_req).await {
tracing::warn!("Failed to send get_workers request during snapshot: {e:?}");
} else {
match resp_rx.await {
Ok(indexer_worker_ids) => {
// Find workers in indexer but not in current instances
for worker_id in indexer_worker_ids {
if !current_worker_ids.contains(&worker_id) {
tracing::info!(
"Removing stale worker {} from indexer during snapshot",
worker_id
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!(
"Failed to send remove_worker for stale worker {}: {e:?}",
worker_id
);
}
}
}
}
Err(e) => {
tracing::warn!("Failed to receive worker IDs from indexer: {e:?}");
}
}
}
// First, purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
// Now request a snapshot from the indexer (which reflects the post-purge state)
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
self.snapshot_tx
.send(dump_req)
.await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
// Upload the snapshot to NATS object store
let url = url::Url::parse(&format!(
"nats://{}/{}/{RADIX_STATE_FILE}",
self.nats_client.addr(),
self.bucket_name
))?;
self.nats_client
.object_store_upload_data(&events, &url)
.await
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
tracing::info!(
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
events.len(),
self.bucket_name,
start_time.elapsed().as_millis()
);
Ok(())
}
}
/// Start a unified background task for event consumption and optional snapshot management
......@@ -73,8 +164,9 @@ pub async fn start_kv_router_background(
component: Component,
consumer_uuid: String,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<crate::kv_router::indexer::WorkerId>,
snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
remove_worker_tx: mpsc::Sender<WorkerId>,
maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
maybe_snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
cancellation_token: CancellationToken,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
......@@ -168,15 +260,30 @@ pub async fn start_kv_router_background(
.await?
.dissolve();
// Only set up snapshot-related resources if snapshot_tx is provided and threshold is set
let snapshot_resources = if snapshot_tx.is_some() && router_snapshot_threshold.is_some() {
// Get instances_rx for tracking current workers
let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() {
dynamo_runtime::component::InstanceSource::Dynamic(rx) => rx.clone(),
dynamo_runtime::component::InstanceSource::Static => {
anyhow::bail!("Expected dynamic instance source for KV routing");
}
};
// Only set up snapshot-related resources if snapshot_tx, get_workers_tx, and threshold are provided
let snapshot_resources = if let (Some(get_workers_tx), Some(snapshot_tx), Some(_)) = (
maybe_get_workers_tx,
maybe_snapshot_tx,
router_snapshot_threshold,
) {
let lock_name = format!("{}/{}", ROUTER_SNAPSHOT_LOCK, component.subject());
Some(SnapshotResources {
nats_client,
bucket_name,
etcd_client: etcd_client.clone(),
lock_name,
instances_rx,
get_workers_tx,
snapshot_tx,
})
} else {
None
......@@ -256,9 +363,9 @@ pub async fn start_kv_router_background(
}
}
// Handle periodic stream checking and purging (only if snapshot_tx is provided)
// Handle periodic stream checking and purging (only if snapshot_resources is provided)
_ = check_interval.tick() => {
let Some((snapshot_tx, resources)) = snapshot_tx.as_ref().zip(snapshot_resources.as_ref()) else {
let Some(resources) = snapshot_resources.as_ref() else {
continue;
};
......@@ -277,22 +384,21 @@ pub async fn start_kv_router_background(
tracing::info!("Stream has {message_count} messages, attempting to acquire lock for purge and snapshot");
// Try to acquire distributed lock
let Some(lock_response) = resources.lock().await else {
let Some(lock_response) = resources.lock(&etcd_client).await else {
continue;
};
// Perform snapshot upload and purge
match purge_then_snapshot(
match resources.purge_then_snapshot(
&mut nats_queue,
snapshot_tx,
resources
&remove_worker_tx,
).await {
Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
Err(e) => tracing::error!("Failed to perform purge and snapshot: {e:?}"),
}
// Release the lock
resources.unlock(lock_response).await;
resources.unlock(&etcd_client, lock_response).await;
}
// Handle router deletion events
......@@ -405,55 +511,3 @@ async fn cleanup_orphaned_consumers(
}
}
}
/// Perform snapshot upload and purge operations
async fn purge_then_snapshot(
nats_queue: &mut NatsQueue,
snapshot_tx: &mpsc::Sender<DumpRequest>,
resources: &SnapshotResources,
) -> anyhow::Result<()> {
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();
// First, purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
// Now request a snapshot from the indexer (which reflects the post-purge state)
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
snapshot_tx
.send(dump_req)
.await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
// Upload the snapshot to NATS object store
let url = url::Url::parse(&format!(
"nats://{}/{}/{RADIX_STATE_FILE}",
resources.nats_client.addr(),
resources.bucket_name
))?;
resources
.nats_client
.object_store_upload_data(&events, &url)
.await
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
tracing::info!(
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
events.len(),
resources.bucket_name,
start_time.elapsed().as_millis()
);
Ok(())
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment