Unverified Commit 1d34af75 authored by jain-ria's avatar jain-ria Committed by GitHub
Browse files

feat: all blocks cleared event (#1279)

parent 7bb21ee7
...@@ -373,6 +373,9 @@ impl RadixTree { ...@@ -373,6 +373,9 @@ impl RadixTree {
worker_lookup.remove(&block); worker_lookup.remove(&block);
} }
} }
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker_id);
}
} }
} }
...@@ -383,6 +386,23 @@ impl RadixTree { ...@@ -383,6 +386,23 @@ impl RadixTree {
}); });
} }
} }
pub fn clear_all_blocks(&mut self, worker: WorkerId) {
// Check if the worker has any blocks to clear
if let Some(blocks) = self.lookup.get(&worker) {
let blocks_to_clear: Vec<_> = blocks.values().collect();
// Remove the worker from each block's workers set
blocks_to_clear.iter().for_each(|block| {
block.borrow_mut().workers.remove(&worker);
});
// Clear the worker's blocks
if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
worker_blocks.clear();
}
}
}
} }
/// Scores representing the overlap of workers. /// Scores representing the overlap of workers.
...@@ -1180,6 +1200,88 @@ mod tests { ...@@ -1180,6 +1200,88 @@ mod tests {
assert!(result.len() == 1 && result[&worker_1] == 1); assert!(result.len() == 1 && result[&worker_1] == 1);
} }
#[test]
fn test_clear_all_blocks() {
let mut trie = RadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
assert!(trie
.find_matches(vec![LocalBlockHash(0)], false)
.scores
.is_empty());
// Test clearing an empty worker
trie.clear_all_blocks(worker_0);
assert!(!trie.lookup.contains_key(&worker_0));
// Test clearing a worker with shared blocks
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None));
trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None));
let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
assert!(result.len() == 2 && result[&worker_0] == 1 && result[&worker_1] == 1);
trie.clear_all_blocks(worker_0);
assert!(trie.lookup.contains_key(&worker_0));
assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
let result = trie
.find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 2);
let result = trie
.find_matches(
vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1);
// Test re-adding blocks after clearing worker
trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None));
let result = trie
.find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_0], 2);
// Test multiple clears
trie.clear_all_blocks(worker_0);
trie.clear_all_blocks(worker_0);
assert!(trie.lookup.contains_key(&worker_0));
// Test clearing all workers
trie.clear_all_blocks(worker_0);
trie.clear_all_blocks(worker_1);
assert!(!trie.lookup.is_empty());
assert!(trie.lookup.get(&worker_0).unwrap().is_empty());
assert!(trie.lookup.get(&worker_1).unwrap().is_empty());
// Test clearing a worker that has been removed
trie.apply_event(create_store_event(worker_0, 0, vec![6], None));
trie.apply_event(create_store_event(worker_1, 0, vec![6], None));
trie.remove_worker(worker_0);
trie.clear_all_blocks(worker_0);
assert!(!trie.lookup.contains_key(&worker_0));
let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1);
// Test clearing a worker that doesn't exist
let worker_fake = 2;
assert!(!trie.lookup.contains_key(&worker_fake));
trie.clear_all_blocks(worker_fake);
assert!(!trie.lookup.contains_key(&worker_fake));
assert!(trie.lookup.contains_key(&worker_1));
let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&worker_1], 1);
}
#[test] #[test]
fn test_early_stopping() { fn test_early_stopping() {
setup(); setup();
......
...@@ -107,10 +107,9 @@ pub struct KvCacheEvent { ...@@ -107,10 +107,9 @@ pub struct KvCacheEvent {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum KvCacheEventData { pub enum KvCacheEventData {
/// Data for a stored cache event.
Stored(KvCacheStoreData), Stored(KvCacheStoreData),
/// Data for a removed cache event.
Removed(KvCacheRemoveData), Removed(KvCacheRemoveData),
Cleared,
} }
/// Represents the data associated with a stored cache event. /// Represents the data associated with a stored cache event.
......
...@@ -318,11 +318,10 @@ async fn start_zmq_listener( ...@@ -318,11 +318,10 @@ async fn start_zmq_listener(
// For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor. // For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor.
for raw_event in batch.events.into_iter() { for raw_event in batch.events.into_iter() {
if let Some(event) = convert_event(raw_event, seq, kv_block_size, &warning_count) { let event = convert_event(raw_event, seq, kv_block_size, &warning_count);
if tx.send(event).is_err() { if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped"); tracing::warn!("Failed to send message to channel - receiver dropped");
return; return;
}
} }
} }
} }
...@@ -332,15 +331,13 @@ async fn start_zmq_listener( ...@@ -332,15 +331,13 @@ async fn start_zmq_listener(
} }
/// Convert a raw event coming from the ZMQ channel into the internal /// Convert a raw event coming from the ZMQ channel into the internal
/// [`KvCacheEvent`] representation used by the router. Returns `None` when the /// [`KvCacheEvent`] representation used by the router.
/// event cannot be represented with the current protocol (e.g., we ignore
/// `AllBlocksCleared` until a concrete format is defined).
fn convert_event( fn convert_event(
raw: RawKvEvent, raw: RawKvEvent,
event_id: u64, event_id: u64,
kv_block_size: usize, kv_block_size: usize,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> Option<KvCacheEvent> { ) -> KvCacheEvent {
match raw { match raw {
RawKvEvent::BlockStored { RawKvEvent::BlockStored {
block_hashes, block_hashes,
...@@ -350,7 +347,7 @@ fn convert_event( ...@@ -350,7 +347,7 @@ fn convert_event(
lora_id, lora_id,
} => { } => {
let num_block_tokens = vec![block_size as u64; block_hashes.len()]; let num_block_tokens = vec![block_size as u64; block_hashes.len()];
Some(KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_block_hash.map(ExternalSequenceBlockHash::from), parent_hash: parent_block_hash.map(ExternalSequenceBlockHash::from),
...@@ -363,24 +360,24 @@ fn convert_event( ...@@ -363,24 +360,24 @@ fn convert_event(
warning_count, warning_count,
), ),
}), }),
}) }
} }
RawKvEvent::BlockRemoved { block_hashes } => { RawKvEvent::BlockRemoved { block_hashes } => {
let hashes = block_hashes let hashes = block_hashes
.into_iter() .into_iter()
.map(ExternalSequenceBlockHash::from) .map(ExternalSequenceBlockHash::from)
.collect(); .collect();
Some(KvCacheEvent { KvCacheEvent {
event_id, event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: hashes, block_hashes: hashes,
}), }),
}) }
}
RawKvEvent::AllBlocksCleared => {
tracing::debug!("Received AllBlocksCleared event – currently ignored");
None
} }
RawKvEvent::AllBlocksCleared => KvCacheEvent {
event_id,
data: KvCacheEventData::Cleared,
},
} }
} }
...@@ -614,7 +611,7 @@ mod test_event_processing { ...@@ -614,7 +611,7 @@ mod test_event_processing {
}; };
let out = convert_event(raw_evt, 42, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 42, kv_block_size, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.unwrap().data, KvCacheEventData::Stored(_))); assert!(matches!(out.data, KvCacheEventData::Stored(_)));
} }
#[test] #[test]
...@@ -625,14 +622,15 @@ mod test_event_processing { ...@@ -625,14 +622,15 @@ mod test_event_processing {
}; };
let out = convert_event(raw_evt, 7, kv_block_size, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 7, kv_block_size, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.unwrap().data, KvCacheEventData::Removed(_))); assert!(matches!(out.data, KvCacheEventData::Removed(_)));
} }
#[test] #[test]
fn test_convert_event_all_blocks_cleared() { fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4; let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared; let raw_evt = RawKvEvent::AllBlocksCleared;
assert!(convert_event(raw_evt, 1, kv_block_size, &Arc::new(AtomicU32::new(0))).is_none()); let out = convert_event(raw_evt, 1, kv_block_size, &Arc::new(AtomicU32::new(0)));
assert!(matches!(out.data, KvCacheEventData::Cleared));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment