Unverified commit 227846f2, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: use Rc to do reference counting in Router slot manager (#3545)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent b94ecd16
......@@ -801,57 +801,50 @@ impl KvIndexer {
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1) // Single-threaded environment
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let local_set = tokio::task::LocalSet::new();
runtime.block_on(local_set.run_until(async move {
tokio::task::spawn_local(async move {
let cancel = cancel_clone;
let mut match_rx = match_rx;
let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx;
let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::debug!("KvCacheIndexer progress loop shutting down");
return;
}
runtime.block_on(async move {
let cancel = cancel_clone;
let mut match_rx = match_rx;
let mut event_rx = event_rx;
let mut remove_worker_rx = remove_worker_rx;
let mut dump_rx = dump_rx;
let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::debug!("KvCacheIndexer progress loop shutting down");
return;
}
Some(worker) = remove_worker_rx.recv() => {
trie.remove_worker(worker);
}
Some(worker) = remove_worker_rx.recv() => {
trie.remove_worker(worker);
}
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
metrics.increment_event_applied(event_type, result);
}
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
metrics.increment_event_applied(event_type, result);
}
Some(dump_req) = dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
Some(dump_req) = dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
Some(req) = match_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
let _ = req.resp.send(matches);
}
Some(req) = match_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
let _ = req.resp.send(matches);
}
}
})
.await
.unwrap()
}));
}
});
tracing::debug!("KvCacheIndexer task completed");
});
......@@ -1058,54 +1051,47 @@ impl KvIndexerSharded {
remove_worker_tx.push(shard_remove_worker_tx);
dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
tasks.push(std::thread::spawn(move || {
let local_set = tokio::task::LocalSet::new();
runtime.block_on(local_set.run_until(async move {
tokio::task::spawn_local(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::trace!("KvCacheIndexer progress loop shutting down");
return;
}
runtime.block_on(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop {
tokio::select! {
biased;
Some(worker) = shard_remove_worker_rx.recv() => {
trie.remove_worker(worker);
}
_ = cancel.cancelled() => {
tracing::trace!("KvCacheIndexer progress loop shutting down");
return;
}
Some(event) = shard_event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
metrics.increment_event_applied(event_type, result);
}
Some(worker) = shard_remove_worker_rx.recv() => {
trie.remove_worker(worker);
}
Some(dump_req) = shard_dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
Some(event) = shard_event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event);
metrics.increment_event_applied(event_type, result);
}
Ok(req) = shard_broadcast_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e);
}
Some(dump_req) = shard_dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
Ok(req) = shard_broadcast_rx.recv() => {
let matches = trie.find_matches(req.sequence, req.early_exit);
if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e);
}
}
}
})
.await
.unwrap()
}));
}
});
tracing::debug!("KvCacheIndexer task completed");
}));
......
......@@ -33,6 +33,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
......@@ -51,18 +52,15 @@ pub type RequestId = String;
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<SequenceHash>>,
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Rc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>,
unique_blocks: HashMap<SequenceHash, Weak<()>>,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
active_blocks: usize,
#[getter(copy)]
active_tokens: usize,
......@@ -84,39 +82,36 @@ impl ActiveSequences {
prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
block_size,
active_blocks: 0,
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
}
}
fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) {
let is_new_block = !self.unique_blocks.contains_key(block);
self.unique_blocks
.entry(*block)
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
fn touch_block(&mut self, block: &SequenceHash) -> Rc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
}
}
fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) {
let Some(request_ids) = self.unique_blocks.get_mut(block) else {
return;
};
let rc = Rc::new(());
self.unique_blocks.insert(*block, Rc::downgrade(&rc));
rc
}
// Remove the unique block if no more requests using it
request_ids.retain(|w| w != request_id);
if request_ids.is_empty() {
self.active_blocks -= 1;
/// Removes the `unique_blocks` entry for `block` once nothing holds a strong
/// reference to it any more.
///
/// Does nothing when `block` has no entry, or when the stored `Weak` still
/// reports at least one strong owner.
fn try_remove_block(&mut self, block: &SequenceHash) {
// Evict only when the stored `Weak` reports zero strong owners.
if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
}
}
/// Returns the number of entries currently stored in `unique_blocks`.
///
/// NOTE(review): `len()` counts every entry, including any whose `Weak` can no
/// longer be upgraded. This presumably relies on `try_remove_block` being
/// called whenever the last `Rc` for a block is dropped — confirm that every
/// path that drops a sequence does so.
pub fn active_blocks(&self) -> usize {
self.unique_blocks.len()
}
/// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request(
......@@ -140,10 +135,12 @@ impl ActiveSequences {
self.active_tokens += prefill_tokens;
if let Some(sequence) = token_sequence {
for block in &sequence {
self.add_block(request_id.clone(), block);
}
self.active_seqs.insert(request_id.clone(), sequence);
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
......@@ -174,9 +171,9 @@ impl ActiveSequences {
overlap: u32,
) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks
self.new_blocks(token_seq) + self.active_blocks()
} else {
self.active_blocks
self.active_blocks()
};
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens)
......@@ -193,7 +190,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks
self.new_blocks(token_sequence) + self.active_blocks()
}
/// Free all blocks associated with a request
......@@ -207,15 +204,17 @@ impl ActiveSequences {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks;
return self.active_blocks();
}
};
for block in token_seq {
self.remove_block(request_id, &block)
// Drop each Rc reference, then clean up the corresponding weak reference
for (block_hash, rc) in token_seq {
drop(rc);
self.try_remove_block(&block_hash);
}
self.active_blocks
self.active_blocks()
}
/// Force expiry of stale requests if the timer has elapsed
......@@ -283,7 +282,7 @@ enum UpdateSequences {
pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>,
handles: Arc<DashMap<WorkerId, std::thread::JoinHandle<()>>>,
block_size: usize,
component: Component,
router_id: Uuid,
......@@ -360,92 +359,98 @@ impl ActiveSequencesMultiWorker {
/// Helper method to start a worker task
fn start_worker(
block_size: usize,
cancel_token: CancellationToken, // Add cancellation token parameter
cancel_token: CancellationToken,
) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
tokio::task::JoinHandle<()>,
std::thread::JoinHandle<()>,
) {
let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = tokio::spawn(async move {
let mut active_sequences = ActiveSequences::new(block_size);
loop {
tokio::select! {
// Handle incoming commands
command = request_rx.recv() => {
match command {
Some(command) => {
match command {
UpdateSequences::AddRequest {
request_id,
token_sequence,
isl,
overlap,
resp_tx,
} => {
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
let _ = resp_tx.send(removed);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id);
}
UpdateSequences::NewBlocks {
token_sequence,
resp_tx,
} => {
let new_blocks = active_sequences.new_blocks(&token_sequence);
let _ = resp_tx.send(new_blocks);
}
UpdateSequences::PotentialBlocks {
token_sequence,
resp_tx,
} => {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = std::thread::spawn(move || {
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(async move {
let mut active_sequences = ActiveSequences::new(block_size);
let mut request_rx = request_rx;
loop {
tokio::select! {
command = request_rx.recv() => {
let Some(command) = command else {
break;
};
match command {
UpdateSequences::AddRequest {
request_id,
token_sequence,
isl,
overlap,
resp_tx,
} => {
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap);
let _ = resp_tx.send(removed);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id);
}
UpdateSequences::NewBlocks {
token_sequence,
resp_tx,
} => {
let new_blocks = active_sequences.new_blocks(&token_sequence);
let _ = resp_tx.send(new_blocks);
}
UpdateSequences::PotentialBlocks {
token_sequence,
resp_tx,
} => {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
isl,
overlap,
resp_tx,
} => {
let potential_tokens = active_sequences.potential_blocks_and_tokens(
token_sequence.as_ref().map(|v| v.as_slice()),
isl,
overlap,
resp_tx,
} => {
let potential_tokens = active_sequences.potential_blocks_and_tokens(
token_sequence.as_ref().map(|v| v.as_slice()),
isl,
overlap,
);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
}
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => {
break;
}
);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
}
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => {
break;
}
}
None => {
// Channel closed, exit
break;
}
}
}
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Worker task cancelled");
break;
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Worker task cancelled");
break;
}
}
}
}
});
tracing::debug!("ActiveSequences worker task completed");
});
(request_tx, handle)
......@@ -560,9 +565,7 @@ impl ActiveSequencesMultiWorker {
if let Some((_, sender)) = self.senders.remove(worker_id) {
let _ = sender.send(UpdateSequences::Shutdown);
}
if let Some((_, handle)) = self.handles.remove(worker_id) {
handle.abort();
}
self.handles.remove(worker_id);
// Clean up request_to_worker mappings for this worker
self.request_to_worker
......@@ -856,11 +859,6 @@ impl Drop for ActiveSequencesMultiWorker {
for entry in self.senders.iter() {
let _ = entry.value().send(UpdateSequences::Shutdown);
}
// Abort all tasks
for entry in self.handles.iter() {
entry.value().abort();
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment