Unverified Commit 227846f2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: use Rc to do reference counting in Router slot manager (#3545)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent b94ecd16
...@@ -801,57 +801,50 @@ impl KvIndexer { ...@@ -801,57 +801,50 @@ impl KvIndexer {
let cancel_clone = token.clone(); let cancel_clone = token.clone();
let task = std::thread::spawn(move || { let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread // Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_current_thread()
.worker_threads(1) // Single-threaded environment
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let local_set = tokio::task::LocalSet::new(); runtime.block_on(async move {
let cancel = cancel_clone;
runtime.block_on(local_set.run_until(async move { let mut match_rx = match_rx;
tokio::task::spawn_local(async move { let mut event_rx = event_rx;
let cancel = cancel_clone; let mut remove_worker_rx = remove_worker_rx;
let mut match_rx = match_rx; let mut dump_rx = dump_rx;
let mut event_rx = event_rx; let mut trie = RadixTree::new_with_frequency(expiration_duration);
let mut remove_worker_rx = remove_worker_rx; loop {
let mut dump_rx = dump_rx; tokio::select! {
let mut trie = RadixTree::new_with_frequency(expiration_duration); biased;
loop {
tokio::select! { _ = cancel.cancelled() => {
biased; tracing::debug!("KvCacheIndexer progress loop shutting down");
return;
_ = cancel.cancelled() => { }
tracing::debug!("KvCacheIndexer progress loop shutting down");
return;
}
Some(worker) = remove_worker_rx.recv() => { Some(worker) = remove_worker_rx.recv() => {
trie.remove_worker(worker); trie.remove_worker(worker);
} }
Some(event) = event_rx.recv() => { Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event); let result = trie.apply_event(event);
metrics.increment_event_applied(event_type, result); metrics.increment_event_applied(event_type, result);
} }
Some(dump_req) = dump_rx.recv() => { Some(dump_req) = dump_rx.recv() => {
let events = trie.dump_tree_as_events(); let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events); let _ = dump_req.resp.send(events);
} }
Some(req) = match_rx.recv() => { Some(req) = match_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit); let matches = trie.find_matches(req.sequence, req.early_exit);
let _ = req.resp.send(matches); let _ = req.resp.send(matches);
}
} }
} }
}) }
.await });
.unwrap()
}));
tracing::debug!("KvCacheIndexer task completed"); tracing::debug!("KvCacheIndexer task completed");
}); });
...@@ -1058,54 +1051,47 @@ impl KvIndexerSharded { ...@@ -1058,54 +1051,47 @@ impl KvIndexerSharded {
remove_worker_tx.push(shard_remove_worker_tx); remove_worker_tx.push(shard_remove_worker_tx);
dump_tx.push(shard_dump_tx); // Store dump sender dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_current_thread()
.worker_threads(1)
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
tasks.push(std::thread::spawn(move || { tasks.push(std::thread::spawn(move || {
let local_set = tokio::task::LocalSet::new(); runtime.block_on(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
runtime.block_on(local_set.run_until(async move { loop {
tokio::task::spawn_local(async move { tokio::select! {
let mut trie = RadixTree::new_with_frequency(expiration_duration); biased;
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::trace!("KvCacheIndexer progress loop shutting down");
return;
}
Some(worker) = shard_remove_worker_rx.recv() => { _ = cancel.cancelled() => {
trie.remove_worker(worker); tracing::trace!("KvCacheIndexer progress loop shutting down");
} return;
}
Some(event) = shard_event_rx.recv() => { Some(worker) = shard_remove_worker_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data); trie.remove_worker(worker);
let result = trie.apply_event(event); }
metrics.increment_event_applied(event_type, result);
}
Some(dump_req) = shard_dump_rx.recv() => { Some(event) = shard_event_rx.recv() => {
let events = trie.dump_tree_as_events(); let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let _ = dump_req.resp.send(events); let result = trie.apply_event(event);
} metrics.increment_event_applied(event_type, result);
}
Ok(req) = shard_broadcast_rx.recv() => { Some(dump_req) = shard_dump_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit); let events = trie.dump_tree_as_events();
if let Err(e) = req.resp.send(matches).await { let _ = dump_req.resp.send(events);
tracing::trace!("Failed to send match response: {:?}", e); }
}
Ok(req) = shard_broadcast_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e);
} }
} }
} }
}) }
.await });
.unwrap()
}));
tracing::debug!("KvCacheIndexer task completed"); tracing::debug!("KvCacheIndexer task completed");
})); }));
......
...@@ -33,6 +33,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider; ...@@ -33,6 +33,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt; use futures::StreamExt;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
...@@ -51,18 +52,15 @@ pub type RequestId = String; ...@@ -51,18 +52,15 @@ pub type RequestId = String;
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug, Getters)]
pub struct ActiveSequences { pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<SequenceHash>>, active_seqs: HashMap<RequestId, Vec<(SequenceHash, Rc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>, prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>, unique_blocks: HashMap<SequenceHash, Weak<()>>,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: usize,
#[getter(copy)]
active_blocks: usize,
#[getter(copy)] #[getter(copy)]
active_tokens: usize, active_tokens: usize,
...@@ -84,39 +82,36 @@ impl ActiveSequences { ...@@ -84,39 +82,36 @@ impl ActiveSequences {
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(), unique_blocks: HashMap::new(),
block_size, block_size,
active_blocks: 0,
active_tokens: 0, active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION, expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(), expiry_requests: HashSet::new(),
} }
} }
fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) { fn touch_block(&mut self, block: &SequenceHash) -> Rc<()> {
let is_new_block = !self.unique_blocks.contains_key(block); if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
self.unique_blocks {
.entry(*block) return rc;
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
} }
}
fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) { let rc = Rc::new(());
let Some(request_ids) = self.unique_blocks.get_mut(block) else { self.unique_blocks.insert(*block, Rc::downgrade(&rc));
return; rc
}; }
// Remove the unique block if no more requests using it fn try_remove_block(&mut self, block: &SequenceHash) {
request_ids.retain(|w| w != request_id); if let Some(weak) = self.unique_blocks.get(block)
if request_ids.is_empty() { && weak.strong_count() == 0
self.active_blocks -= 1; {
self.unique_blocks.remove(block); self.unique_blocks.remove(block);
} }
} }
pub fn active_blocks(&self) -> usize {
self.unique_blocks.len()
}
/// Add a new request with its initial tokens /// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup /// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request( pub fn add_request(
...@@ -140,10 +135,12 @@ impl ActiveSequences { ...@@ -140,10 +135,12 @@ impl ActiveSequences {
self.active_tokens += prefill_tokens; self.active_tokens += prefill_tokens;
if let Some(sequence) = token_sequence { if let Some(sequence) = token_sequence {
for block in &sequence { let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
self.add_block(request_id.clone(), block); .iter()
} .map(|block| (*block, self.touch_block(block)))
self.active_seqs.insert(request_id.clone(), sequence); .collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else { } else {
// dummy empty sequence // dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new()); self.active_seqs.insert(request_id.clone(), Vec::new());
...@@ -174,9 +171,9 @@ impl ActiveSequences { ...@@ -174,9 +171,9 @@ impl ActiveSequences {
overlap: u32, overlap: u32,
) -> (usize, usize) { ) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence { let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks self.new_blocks(token_seq) + self.active_blocks()
} else { } else {
self.active_blocks self.active_blocks()
}; };
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens; let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
...@@ -193,7 +190,7 @@ impl ActiveSequences { ...@@ -193,7 +190,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added /// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks /// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize { pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks self.new_blocks(token_sequence) + self.active_blocks()
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
...@@ -207,15 +204,17 @@ impl ActiveSequences { ...@@ -207,15 +204,17 @@ impl ActiveSequences {
Some(seq) => seq, Some(seq) => seq,
None => { None => {
tracing::warn!("Trying to free non-existent request {request_id}"); tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks; return self.active_blocks();
} }
}; };
for block in token_seq { // Drop each Rc reference, then clean up the corresponding weak reference
self.remove_block(request_id, &block) for (block_hash, rc) in token_seq {
drop(rc);
self.try_remove_block(&block_hash);
} }
self.active_blocks self.active_blocks()
} }
/// Force expiry of stale requests if the timer has elapsed /// Force expiry of stale requests if the timer has elapsed
...@@ -283,7 +282,7 @@ enum UpdateSequences { ...@@ -283,7 +282,7 @@ enum UpdateSequences {
pub struct ActiveSequencesMultiWorker { pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>, senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>, request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>, handles: Arc<DashMap<WorkerId, std::thread::JoinHandle<()>>>,
block_size: usize, block_size: usize,
component: Component, component: Component,
router_id: Uuid, router_id: Uuid,
...@@ -360,92 +359,98 @@ impl ActiveSequencesMultiWorker { ...@@ -360,92 +359,98 @@ impl ActiveSequencesMultiWorker {
/// Helper method to start a worker task /// Helper method to start a worker task
fn start_worker( fn start_worker(
block_size: usize, block_size: usize,
cancel_token: CancellationToken, // Add cancellation token parameter cancel_token: CancellationToken,
) -> ( ) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>, tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
tokio::task::JoinHandle<()>, std::thread::JoinHandle<()>,
) { ) {
let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel(); let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = tokio::spawn(async move { let handle = std::thread::spawn(move || {
let mut active_sequences = ActiveSequences::new(block_size); // Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
loop { .enable_all()
tokio::select! { .build()
// Handle incoming commands .unwrap();
command = request_rx.recv() => {
match command { runtime.block_on(async move {
Some(command) => { let mut active_sequences = ActiveSequences::new(block_size);
match command { let mut request_rx = request_rx;
UpdateSequences::AddRequest {
request_id, loop {
token_sequence, tokio::select! {
isl, command = request_rx.recv() => {
overlap, let Some(command) = command else {
resp_tx, break;
} => { };
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
let _ = resp_tx.send(removed); match command {
} UpdateSequences::AddRequest {
UpdateSequences::Free { request_id } => { request_id,
active_sequences.free(&request_id); token_sequence,
} isl,
UpdateSequences::MarkPrefillCompleted { request_id } => { overlap,
active_sequences.mark_prefill_completed(&request_id); resp_tx,
} } => {
UpdateSequences::NewBlocks { let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
token_sequence, let _ = resp_tx.send(removed);
resp_tx, }
} => { UpdateSequences::Free { request_id } => {
let new_blocks = active_sequences.new_blocks(&token_sequence); active_sequences.free(&request_id);
let _ = resp_tx.send(new_blocks); }
} UpdateSequences::MarkPrefillCompleted { request_id } => {
UpdateSequences::PotentialBlocks { active_sequences.mark_prefill_completed(&request_id);
token_sequence, }
resp_tx, UpdateSequences::NewBlocks {
} => { token_sequence,
let potential_blocks = active_sequences.potential_blocks(&token_sequence); resp_tx,
let _ = resp_tx.send(potential_blocks); } => {
} let new_blocks = active_sequences.new_blocks(&token_sequence);
UpdateSequences::PotentialBlocksAndTokens { let _ = resp_tx.send(new_blocks);
token_sequence, }
UpdateSequences::PotentialBlocks {
token_sequence,
resp_tx,
} => {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
isl,
overlap,
resp_tx,
} => {
let potential_tokens = active_sequences.potential_blocks_and_tokens(
token_sequence.as_ref().map(|v| v.as_slice()),
isl, isl,
overlap, overlap,
resp_tx, );
} => { let _ = resp_tx.send(potential_tokens);
let potential_tokens = active_sequences.potential_blocks_and_tokens( }
token_sequence.as_ref().map(|v| v.as_slice()), UpdateSequences::ActiveBlocks { resp_tx } => {
isl, let active_blocks = active_sequences.active_blocks();
overlap, let _ = resp_tx.send(active_blocks);
); }
let _ = resp_tx.send(potential_tokens); UpdateSequences::ActiveTokens { resp_tx } => {
} let active_tokens = active_sequences.active_tokens();
UpdateSequences::ActiveBlocks { resp_tx } => { let _ = resp_tx.send(active_tokens);
let active_blocks = active_sequences.active_blocks(); }
let _ = resp_tx.send(active_blocks); UpdateSequences::Shutdown => {
} break;
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => {
break;
}
} }
}
None => {
// Channel closed, exit
break;
} }
} }
} // Handle cancellation
// Handle cancellation _ = cancel_token.cancelled() => {
_ = cancel_token.cancelled() => { tracing::debug!("Worker task cancelled");
tracing::debug!("Worker task cancelled"); break;
break; }
} }
} }
} });
tracing::debug!("ActiveSequences worker task completed");
}); });
(request_tx, handle) (request_tx, handle)
...@@ -560,9 +565,7 @@ impl ActiveSequencesMultiWorker { ...@@ -560,9 +565,7 @@ impl ActiveSequencesMultiWorker {
if let Some((_, sender)) = self.senders.remove(worker_id) { if let Some((_, sender)) = self.senders.remove(worker_id) {
let _ = sender.send(UpdateSequences::Shutdown); let _ = sender.send(UpdateSequences::Shutdown);
} }
if let Some((_, handle)) = self.handles.remove(worker_id) { self.handles.remove(worker_id);
handle.abort();
}
// Clean up request_to_worker mappings for this worker // Clean up request_to_worker mappings for this worker
self.request_to_worker self.request_to_worker
...@@ -856,11 +859,6 @@ impl Drop for ActiveSequencesMultiWorker { ...@@ -856,11 +859,6 @@ impl Drop for ActiveSequencesMultiWorker {
for entry in self.senders.iter() { for entry in self.senders.iter() {
let _ = entry.value().send(UpdateSequences::Shutdown); let _ = entry.value().send(UpdateSequences::Shutdown);
} }
// Abort all tasks
for entry in self.handles.iter() {
entry.value().abort();
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment