Unverified commit 227846f2, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: use Rc to do reference counting in Router slot manager (#3545)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent b94ecd16
......@@ -801,17 +801,13 @@ impl KvIndexer {
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1) // Single-threaded environment
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let local_set = tokio::task::LocalSet::new();
runtime.block_on(local_set.run_until(async move {
tokio::task::spawn_local(async move {
runtime.block_on(async move {
let cancel = cancel_clone;
let mut match_rx = match_rx;
let mut event_rx = event_rx;
......@@ -848,10 +844,7 @@ impl KvIndexer {
}
}
}
})
.await
.unwrap()
}));
});
tracing::debug!("KvCacheIndexer task completed");
});
......@@ -1058,17 +1051,13 @@ impl KvIndexerSharded {
remove_worker_tx.push(shard_remove_worker_tx);
dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
tasks.push(std::thread::spawn(move || {
let local_set = tokio::task::LocalSet::new();
runtime.block_on(local_set.run_until(async move {
tokio::task::spawn_local(async move {
runtime.block_on(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop {
tokio::select! {
......@@ -1102,10 +1091,7 @@ impl KvIndexerSharded {
}
}
}
})
.await
.unwrap()
}));
});
tracing::debug!("KvCacheIndexer task completed");
}));
......
......@@ -33,6 +33,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
......@@ -51,18 +52,15 @@ pub type RequestId = String;
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<SequenceHash>>,
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Rc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>,
unique_blocks: HashMap<SequenceHash, Weak<()>>,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
active_blocks: usize,
#[getter(copy)]
active_tokens: usize,
......@@ -84,39 +82,36 @@ impl ActiveSequences {
prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
block_size,
active_blocks: 0,
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
}
}
fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) {
let is_new_block = !self.unique_blocks.contains_key(block);
self.unique_blocks
.entry(*block)
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
}
fn touch_block(&mut self, block: &SequenceHash) -> Rc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
}
fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
let Some(request_ids) = self.unique_blocks.get_mut(block) else {
return;
};
let rc = Rc::new(());
self.unique_blocks.insert(*block, Rc::downgrade(&rc));
rc
}
// Remove the unique block if no more requests using it
request_ids.retain(|w| w != request_id);
if request_ids.is_empty() {
self.active_blocks -= 1;
fn try_remove_block(&mut self, block: &SequenceHash) {
if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
}
}
/// Returns how many entries `unique_blocks` currently holds.
///
/// The value is computed on demand from the map rather than read from a
/// separately maintained counter.
pub fn active_blocks(&self) -> usize {
    let tracked_blocks = &self.unique_blocks;
    tracked_blocks.len()
}
/// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request(
......@@ -140,10 +135,12 @@ impl ActiveSequences {
self.active_tokens += prefill_tokens;
if let Some(sequence) = token_sequence {
for block in &sequence {
self.add_block(request_id.clone(), block);
}
self.active_seqs.insert(request_id.clone(), sequence);
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
......@@ -174,9 +171,9 @@ impl ActiveSequences {
overlap: u32,
) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks
self.new_blocks(token_seq) + self.active_blocks()
} else {
self.active_blocks
self.active_blocks()
};
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens)
......@@ -193,7 +190,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks
self.new_blocks(token_sequence) + self.active_blocks()
}
/// Free all blocks associated with a request
......@@ -207,15 +204,17 @@ impl ActiveSequences {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks;
return self.active_blocks();
}
};
for block in token_seq {
self.remove_block(request_id, &block)
// Drop each Rc reference, then clean up the corresponding weak reference
for (block_hash, rc) in token_seq {
drop(rc);
self.try_remove_block(&block_hash);
}
self.active_blocks
self.active_blocks()
}
/// Force expiry of stale requests if the timer has elapsed
......@@ -283,7 +282,7 @@ enum UpdateSequences {
pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>,
handles: Arc<DashMap<WorkerId, std::thread::JoinHandle<()>>>,
block_size: usize,
component: Component,
router_id: Uuid,
......@@ -360,22 +359,31 @@ impl ActiveSequencesMultiWorker {
/// Helper method to start a worker task
fn start_worker(
block_size: usize,
cancel_token: CancellationToken, // Add cancellation token parameter
cancel_token: CancellationToken,
) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
tokio::task::JoinHandle<()>,
std::thread::JoinHandle<()>,
) {
let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel();
let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = std::thread::spawn(move || {
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let handle = tokio::spawn(async move {
runtime.block_on(async move {
let mut active_sequences = ActiveSequences::new(block_size);
let mut request_rx = request_rx;
loop {
tokio::select! {
// Handle incoming commands
command = request_rx.recv() => {
match command {
Some(command) => {
let Some(command) = command else {
break;
};
match command {
UpdateSequences::AddRequest {
request_id,
......@@ -433,12 +441,6 @@ impl ActiveSequencesMultiWorker {
}
}
}
None => {
// Channel closed, exit
break;
}
}
}
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Worker task cancelled");
......@@ -448,6 +450,9 @@ impl ActiveSequencesMultiWorker {
}
});
tracing::debug!("ActiveSequences worker task completed");
});
(request_tx, handle)
}
......@@ -560,9 +565,7 @@ impl ActiveSequencesMultiWorker {
if let Some((_, sender)) = self.senders.remove(worker_id) {
let _ = sender.send(UpdateSequences::Shutdown);
}
if let Some((_, handle)) = self.handles.remove(worker_id) {
handle.abort();
}
self.handles.remove(worker_id);
// Clean up request_to_worker mappings for this worker
self.request_to_worker
......@@ -856,11 +859,6 @@ impl Drop for ActiveSequencesMultiWorker {
for entry in self.senders.iter() {
let _ = entry.value().send(UpdateSequences::Shutdown);
}
// Abort all tasks
for entry in self.handles.iter() {
entry.value().abort();
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment