"deploy/snapshot/vscode:/vscode.git/clone" did not exist on "ed4d8068f3309b965b2f5bc09911093f84d7aa92"
Unverified Commit 227846f2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: use Rc to do reference counting in Router slot manager (#3545)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent b94ecd16
...@@ -801,17 +801,13 @@ impl KvIndexer { ...@@ -801,17 +801,13 @@ impl KvIndexer {
let cancel_clone = token.clone(); let cancel_clone = token.clone();
let task = std::thread::spawn(move || { let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread // Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_current_thread()
.worker_threads(1) // Single-threaded environment
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let local_set = tokio::task::LocalSet::new(); runtime.block_on(async move {
runtime.block_on(local_set.run_until(async move {
tokio::task::spawn_local(async move {
let cancel = cancel_clone; let cancel = cancel_clone;
let mut match_rx = match_rx; let mut match_rx = match_rx;
let mut event_rx = event_rx; let mut event_rx = event_rx;
...@@ -848,10 +844,7 @@ impl KvIndexer { ...@@ -848,10 +844,7 @@ impl KvIndexer {
} }
} }
} }
}) });
.await
.unwrap()
}));
tracing::debug!("KvCacheIndexer task completed"); tracing::debug!("KvCacheIndexer task completed");
}); });
...@@ -1058,17 +1051,13 @@ impl KvIndexerSharded { ...@@ -1058,17 +1051,13 @@ impl KvIndexerSharded {
remove_worker_tx.push(shard_remove_worker_tx); remove_worker_tx.push(shard_remove_worker_tx);
dump_tx.push(shard_dump_tx); // Store dump sender dump_tx.push(shard_dump_tx); // Store dump sender
let runtime = tokio::runtime::Builder::new_multi_thread() let runtime = tokio::runtime::Builder::new_current_thread()
.worker_threads(1)
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
tasks.push(std::thread::spawn(move || { tasks.push(std::thread::spawn(move || {
let local_set = tokio::task::LocalSet::new(); runtime.block_on(async move {
runtime.block_on(local_set.run_until(async move {
tokio::task::spawn_local(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration); let mut trie = RadixTree::new_with_frequency(expiration_duration);
loop { loop {
tokio::select! { tokio::select! {
...@@ -1102,10 +1091,7 @@ impl KvIndexerSharded { ...@@ -1102,10 +1091,7 @@ impl KvIndexerSharded {
} }
} }
} }
}) });
.await
.unwrap()
}));
tracing::debug!("KvCacheIndexer task completed"); tracing::debug!("KvCacheIndexer task completed");
})); }));
......
...@@ -33,6 +33,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider; ...@@ -33,6 +33,7 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber}; use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt; use futures::StreamExt;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
...@@ -51,18 +52,15 @@ pub type RequestId = String; ...@@ -51,18 +52,15 @@ pub type RequestId = String;
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug, Getters)]
pub struct ActiveSequences { pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<SequenceHash>>, active_seqs: HashMap<RequestId, Vec<(SequenceHash, Rc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>, prefill_tokens: HashMap<RequestId, usize>,
unique_blocks: HashMap<SequenceHash, HashSet<RequestId>>, unique_blocks: HashMap<SequenceHash, Weak<()>>,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: usize,
#[getter(copy)]
active_blocks: usize,
#[getter(copy)] #[getter(copy)]
active_tokens: usize, active_tokens: usize,
...@@ -84,39 +82,36 @@ impl ActiveSequences { ...@@ -84,39 +82,36 @@ impl ActiveSequences {
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
unique_blocks: HashMap::new(), unique_blocks: HashMap::new(),
block_size, block_size,
active_blocks: 0,
active_tokens: 0, active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION, expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(), expiry_requests: HashSet::new(),
} }
} }
fn add_block(&mut self, request_id: RequestId, block: &SequenceHash) { fn touch_block(&mut self, block: &SequenceHash) -> Rc<()> {
let is_new_block = !self.unique_blocks.contains_key(block); if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
self.unique_blocks {
.entry(*block) return rc;
.or_default()
.insert(request_id.clone());
if is_new_block {
self.active_blocks += 1;
}
} }
fn remove_block(&mut self, request_id: &RequestId, block: &SequenceHash) { let rc = Rc::new(());
let Some(request_ids) = self.unique_blocks.get_mut(block) else { self.unique_blocks.insert(*block, Rc::downgrade(&rc));
return; rc
}; }
// Remove the unique block if no more requests using it fn try_remove_block(&mut self, block: &SequenceHash) {
request_ids.retain(|w| w != request_id); if let Some(weak) = self.unique_blocks.get(block)
if request_ids.is_empty() { && weak.strong_count() == 0
self.active_blocks -= 1; {
self.unique_blocks.remove(block); self.unique_blocks.remove(block);
} }
} }
pub fn active_blocks(&self) -> usize {
self.unique_blocks.len()
}
/// Add a new request with its initial tokens /// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup /// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request( pub fn add_request(
...@@ -140,10 +135,12 @@ impl ActiveSequences { ...@@ -140,10 +135,12 @@ impl ActiveSequences {
self.active_tokens += prefill_tokens; self.active_tokens += prefill_tokens;
if let Some(sequence) = token_sequence { if let Some(sequence) = token_sequence {
for block in &sequence { let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
self.add_block(request_id.clone(), block); .iter()
} .map(|block| (*block, self.touch_block(block)))
self.active_seqs.insert(request_id.clone(), sequence); .collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else { } else {
// dummy empty sequence // dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new()); self.active_seqs.insert(request_id.clone(), Vec::new());
...@@ -174,9 +171,9 @@ impl ActiveSequences { ...@@ -174,9 +171,9 @@ impl ActiveSequences {
overlap: u32, overlap: u32,
) -> (usize, usize) { ) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence { let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks self.new_blocks(token_seq) + self.active_blocks()
} else { } else {
self.active_blocks self.active_blocks()
}; };
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens; let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
...@@ -193,7 +190,7 @@ impl ActiveSequences { ...@@ -193,7 +190,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added /// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks /// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize { pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks self.new_blocks(token_sequence) + self.active_blocks()
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request
...@@ -207,15 +204,17 @@ impl ActiveSequences { ...@@ -207,15 +204,17 @@ impl ActiveSequences {
Some(seq) => seq, Some(seq) => seq,
None => { None => {
tracing::warn!("Trying to free non-existent request {request_id}"); tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks; return self.active_blocks();
} }
}; };
for block in token_seq { // Drop each Rc reference, then clean up the corresponding weak reference
self.remove_block(request_id, &block) for (block_hash, rc) in token_seq {
drop(rc);
self.try_remove_block(&block_hash);
} }
self.active_blocks self.active_blocks()
} }
/// Force expiry of stale requests if the timer has elapsed /// Force expiry of stale requests if the timer has elapsed
...@@ -283,7 +282,7 @@ enum UpdateSequences { ...@@ -283,7 +282,7 @@ enum UpdateSequences {
pub struct ActiveSequencesMultiWorker { pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>, senders: Arc<DashMap<WorkerId, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
request_to_worker: Arc<DashMap<RequestId, WorkerId>>, request_to_worker: Arc<DashMap<RequestId, WorkerId>>,
handles: Arc<DashMap<WorkerId, tokio::task::JoinHandle<()>>>, handles: Arc<DashMap<WorkerId, std::thread::JoinHandle<()>>>,
block_size: usize, block_size: usize,
component: Component, component: Component,
router_id: Uuid, router_id: Uuid,
...@@ -360,22 +359,31 @@ impl ActiveSequencesMultiWorker { ...@@ -360,22 +359,31 @@ impl ActiveSequencesMultiWorker {
/// Helper method to start a worker task /// Helper method to start a worker task
fn start_worker( fn start_worker(
block_size: usize, block_size: usize,
cancel_token: CancellationToken, // Add cancellation token parameter cancel_token: CancellationToken,
) -> ( ) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>, tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
tokio::task::JoinHandle<()>, std::thread::JoinHandle<()>,
) { ) {
let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel(); let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = std::thread::spawn(move || {
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let handle = tokio::spawn(async move { runtime.block_on(async move {
let mut active_sequences = ActiveSequences::new(block_size); let mut active_sequences = ActiveSequences::new(block_size);
let mut request_rx = request_rx;
loop { loop {
tokio::select! { tokio::select! {
// Handle incoming commands
command = request_rx.recv() => { command = request_rx.recv() => {
match command { let Some(command) = command else {
Some(command) => { break;
};
match command { match command {
UpdateSequences::AddRequest { UpdateSequences::AddRequest {
request_id, request_id,
...@@ -433,12 +441,6 @@ impl ActiveSequencesMultiWorker { ...@@ -433,12 +441,6 @@ impl ActiveSequencesMultiWorker {
} }
} }
} }
None => {
// Channel closed, exit
break;
}
}
}
// Handle cancellation // Handle cancellation
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
tracing::debug!("Worker task cancelled"); tracing::debug!("Worker task cancelled");
...@@ -448,6 +450,9 @@ impl ActiveSequencesMultiWorker { ...@@ -448,6 +450,9 @@ impl ActiveSequencesMultiWorker {
} }
}); });
tracing::debug!("ActiveSequences worker task completed");
});
(request_tx, handle) (request_tx, handle)
} }
...@@ -560,9 +565,7 @@ impl ActiveSequencesMultiWorker { ...@@ -560,9 +565,7 @@ impl ActiveSequencesMultiWorker {
if let Some((_, sender)) = self.senders.remove(worker_id) { if let Some((_, sender)) = self.senders.remove(worker_id) {
let _ = sender.send(UpdateSequences::Shutdown); let _ = sender.send(UpdateSequences::Shutdown);
} }
if let Some((_, handle)) = self.handles.remove(worker_id) { self.handles.remove(worker_id);
handle.abort();
}
// Clean up request_to_worker mappings for this worker // Clean up request_to_worker mappings for this worker
self.request_to_worker self.request_to_worker
...@@ -856,11 +859,6 @@ impl Drop for ActiveSequencesMultiWorker { ...@@ -856,11 +859,6 @@ impl Drop for ActiveSequencesMultiWorker {
for entry in self.senders.iter() { for entry in self.senders.iter() {
let _ = entry.value().send(UpdateSequences::Shutdown); let _ = entry.value().send(UpdateSequences::Shutdown);
} }
// Abort all tasks
for entry in self.handles.iter() {
entry.value().abort();
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment