"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "2bf27924a1303678798c799ad5560089d9ab93c9"
Unverified Commit efa89448 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: de-async scheduler read paths and unblock decode output tracking (#6510)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 6fab12be
......@@ -494,14 +494,12 @@ impl KvRouter {
self.scheduler.worker_type()
}
pub async fn add_output_block(
pub fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.scheduler
.add_output_block(request_id, decay_fraction)
.await
self.scheduler.add_output_block(request_id, decay_fraction)
}
pub fn block_size(&self) -> u32 {
......@@ -541,8 +539,7 @@ impl KvRouter {
Ok(self
.scheduler
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
.await)
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores))
}
/// Dump all events from the indexer
......
......@@ -108,7 +108,6 @@ impl RequestGuard {
if let Err(e) = self
.chooser
.add_output_block(&self.context_id, decay_fraction)
.await
{
tracing::warn!(
"Failed to add output block for request {}: {e}",
......
......@@ -103,7 +103,7 @@ impl SchedulerQueue {
return;
};
if self.all_workers_busy(threshold).await {
if self.all_workers_busy(threshold) {
tracing::debug!("all workers busy, queueing request");
let entry = self.make_entry(request);
self.pending.lock().await.push(entry);
......@@ -121,7 +121,7 @@ impl SchedulerQueue {
};
loop {
if self.all_workers_busy(threshold).await {
if self.all_workers_busy(threshold) {
break;
}
let Some(entry) = self.pending.lock().await.pop() else {
......@@ -135,14 +135,11 @@ impl SchedulerQueue {
/// Run the full scheduling pipeline for a single request:
/// compute potential load → select worker → respond → book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens(
let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens(
request.token_seq.clone(),
request.isl_tokens,
request.overlaps.clone(),
)
.await;
);
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
......@@ -194,8 +191,8 @@ impl SchedulerQueue {
/// Check if all workers are busy based on threshold.
/// Returns true only if ALL workers exceed the threshold (no worker has capacity).
async fn all_workers_busy(&self, threshold: f64) -> bool {
let active_tokens = self.slots.active_tokens().await;
fn all_workers_busy(&self, threshold: f64) -> bool {
let active_tokens = self.slots.active_tokens();
let configs = self.workers_with_configs.borrow();
for (&worker_id, config) in configs.iter() {
......
......@@ -272,17 +272,16 @@ impl KvScheduler {
self.slots.worker_type()
}
pub async fn add_output_block(
pub fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
.await
}
pub async fn get_potential_loads(
pub fn get_potential_loads(
&self,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
......@@ -290,8 +289,7 @@ impl KvScheduler {
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
.await;
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps);
// Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
......
......@@ -9,12 +9,11 @@
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Single-threaded sequence manager that tracks active requests and their
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//! token sequences, managing shared KV cache blocks efficiently.
//!
//! - [`ActiveSequencesMultiWorker`]: Multi-threaded extension that distributes sequence management
//! across multiple worker threads, enabling parallel processing of requests while maintaining
//! consistency.
//! - [`ActiveSequencesMultiWorker`]: Multi-worker extension that stores per-worker
//! `ActiveSequences` in a shared `DashMap` for lock-free concurrent access.
//!
//! # Architecture
//!
......@@ -31,7 +30,6 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
......@@ -62,9 +60,6 @@ pub enum SequenceError {
#[error("Failed to publish event: {0}")]
PublishFailed(#[from] anyhow::Error),
#[error("Failed to send command to worker: channel closed")]
WorkerChannelClosed,
}
/// Duration after which stale requests are forcibly expired (5 minutes)
......@@ -87,14 +82,14 @@ pub struct SequenceRequest {
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Rc<()>)>>,
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, Weak<()>>,
unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
......@@ -133,15 +128,15 @@ impl ActiveSequences {
}
}
fn touch_block(&mut self, block: &SequenceHash) -> Rc<()> {
fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
}
let rc = Rc::new(());
self.unique_blocks.insert(*block, Rc::downgrade(&rc));
let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc
}
......@@ -177,7 +172,7 @@ impl ActiveSequences {
for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference
if Rc::strong_count(rc) == 1 {
if Arc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction);
}
}
......@@ -214,7 +209,7 @@ impl ActiveSequences {
}
if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
......@@ -370,64 +365,16 @@ impl ActiveSequences {
}
}
enum UpdateSequences {
AddRequest {
request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
},
Free {
request_id: RequestId,
},
MarkPrefillCompleted {
request_id: RequestId,
},
AddOutputBlock {
request_id: RequestId,
decay_fraction: Option<f64>,
resp_tx: tokio::sync::oneshot::Sender<bool>,
},
NewBlocks {
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
PotentialBlocks {
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
PotentialBlocksAndTokens {
token_sequence: Option<Arc<Vec<SequenceHash>>>,
isl: usize,
overlap: u32,
resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>,
},
ActiveBlocks {
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
ActiveTokens {
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
Shutdown,
}
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
/// Multi-worker extension of ActiveSequences using shared DashMap for lock-free concurrent access
pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>,
workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
handles: Arc<DashMap<WorkerWithDpRank, std::thread::JoinHandle<()>>>,
block_size: usize,
component: Component,
router_id: u64,
/// Publisher for sequence events
event_publisher: EventPublisher,
/// Publisher for metrics (namespace-scoped)
metrics_publisher: EventPublisher,
metrics_publisher: Arc<EventPublisher>,
replica_sync: bool,
/// Worker type for Prometheus metrics labeling ("prefill" or "decode")
worker_type: &'static str,
}
......@@ -442,37 +389,30 @@ impl ActiveSequencesMultiWorker {
) -> Result<Self> {
assert!(block_size > 1, "block_size must be greater than 1");
let senders = Arc::new(DashMap::new());
let handles = Arc::new(DashMap::new());
let workers = Arc::new(DashMap::new());
let request_to_worker = Arc::new(DashMap::new());
let request_to_lora = Arc::new(DashMap::new());
// Expand workers by their dp_rank
for (worker_id, config) in workers_with_configs {
let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Create a child cancellation token from the component's runtime
let cancel_token = component.drt().runtime().child_token();
let (sender, handle) = Self::start_worker(block_size, cancel_token);
senders.insert(worker, sender);
handles.insert(worker, handle);
workers.insert(worker, ActiveSequences::new(block_size));
}
}
let event_publisher =
EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
let metrics_publisher =
EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?;
let metrics_publisher = Arc::new(
EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?,
);
let multi_worker = Self {
senders: senders.clone(),
workers: workers.clone(),
request_to_worker: request_to_worker.clone(),
request_to_lora: request_to_lora.clone(),
handles,
block_size,
component: component.clone(),
event_publisher,
metrics_publisher,
router_id,
......@@ -480,9 +420,8 @@ impl ActiveSequencesMultiWorker {
worker_type,
};
// Start the subscription loop only if replica_sync is enabled
if replica_sync {
let senders_clone = senders.clone();
let workers_clone = workers.clone();
let request_to_worker_clone = request_to_worker.clone();
let request_to_lora_clone = request_to_lora.clone();
let component_clone = component.clone();
......@@ -490,9 +429,8 @@ impl ActiveSequencesMultiWorker {
let cancel_token = component.drt().runtime().child_token();
tokio::spawn(async move {
// NATS subscription loop
if let Err(e) = Self::subscribe_to_events(
senders_clone,
workers_clone,
request_to_worker_clone,
request_to_lora_clone,
component_clone,
......@@ -509,120 +447,9 @@ impl ActiveSequencesMultiWorker {
Ok(multi_worker)
}
/// Helper method to start a worker task
fn start_worker(
block_size: usize,
cancel_token: CancellationToken,
) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
std::thread::JoinHandle<()>,
) {
let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = std::thread::spawn(move || {
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(async move {
let mut active_sequences = ActiveSequences::new(block_size);
let mut request_rx = request_rx;
loop {
tokio::select! {
command = request_rx.recv() => {
let Some(command) = command else {
break;
};
match command {
UpdateSequences::AddRequest {
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
resp_tx,
} => {
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap, expected_output_tokens);
let _ = resp_tx.send(removed);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id);
}
UpdateSequences::AddOutputBlock {
request_id,
decay_fraction,
resp_tx,
} => {
let success = active_sequences.add_output_block(&request_id, decay_fraction);
let _ = resp_tx.send(success);
}
UpdateSequences::NewBlocks {
token_sequence,
resp_tx,
} => {
let new_blocks = active_sequences.new_blocks(&token_sequence);
let _ = resp_tx.send(new_blocks);
}
UpdateSequences::PotentialBlocks {
token_sequence,
resp_tx,
} => {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
isl,
overlap,
resp_tx,
} => {
let potential_tokens = active_sequences.potential_blocks_and_tokens(
token_sequence.as_ref().map(|v| v.as_slice()),
isl,
overlap,
);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
}
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => {
break;
}
}
}
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Worker task cancelled");
break;
}
}
}
});
tracing::debug!("ActiveSequences worker task completed");
});
(request_tx, handle)
}
/// Background task to subscribe to active sequence events and update all workers
async fn subscribe_to_events(
senders: Arc<
DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>,
>,
workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
component: Component,
......@@ -635,10 +462,8 @@ impl ActiveSequencesMultiWorker {
loop {
tokio::select! {
// Handle incoming events
result = subscriber.next() => {
let Some(result) = result else {
// Stream ended
break;
};
......@@ -650,7 +475,6 @@ impl ActiveSequencesMultiWorker {
continue;
};
// Skip events emitted by itself
if event.router_id == router_id {
continue;
}
......@@ -664,22 +488,18 @@ impl ActiveSequencesMultiWorker {
} => {
request_to_worker.insert(event.request_id.clone(), event.worker);
// Store lora_name mapping if present
if let Some(ref lora_name) = event.lora_name {
request_to_lora.insert(event.request_id.clone(), lora_name.clone());
}
if let Some(sender) = senders.get(&event.worker) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests
let (resp_tx, _) = tokio::sync::oneshot::channel();
let _ = sender.send(UpdateSequences::AddRequest {
request_id: event.request_id.clone(),
token_sequence: token_sequence.clone(),
isl: *isl,
overlap: *overlap,
expected_output_tokens: *expected_output_tokens,
resp_tx,
});
if let Some(mut entry) = workers.get_mut(&event.worker) {
entry.add_request(
event.request_id.clone(),
token_sequence.clone(),
*isl,
*overlap,
*expected_output_tokens,
);
} else {
tracing::warn!(
"Worker {:?} not found, cannot process AddRequest",
......@@ -689,27 +509,21 @@ impl ActiveSequencesMultiWorker {
}
ActiveSequenceEventData::Free => {
if let Some((_, worker)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker)
&& let Some(mut entry) = workers.get_mut(&worker)
{
let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(),
});
entry.free(&event.request_id);
}
// Clean up lora_name mapping
request_to_lora.remove(&event.request_id);
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker)
&& let Some(mut entry) = workers.get_mut(&*worker)
{
let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(),
});
entry.mark_prefill_completed(&event.request_id);
}
}
}
}
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled");
break;
......@@ -723,9 +537,8 @@ impl ActiveSequencesMultiWorker {
/// Update the set of workers, adding and removing as needed
pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) {
let current_workers: HashSet<WorkerWithDpRank> =
self.senders.iter().map(|entry| *entry.key()).collect();
self.workers.iter().map(|entry| *entry.key()).collect();
// Expand new workers by their dp_rank
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.data_parallel_size;
......@@ -740,17 +553,11 @@ impl ActiveSequencesMultiWorker {
let workers_to_add: Vec<WorkerWithDpRank> =
new_workers.difference(&current_workers).copied().collect();
// Remove workers (this will naturally remove all dp ranks for a worker_id)
for worker in &workers_to_remove {
tracing::warn!("Removing worker {:?}", worker);
// Send shutdown command to the worker
if let Some((_, sender)) = self.senders.remove(worker) {
let _ = sender.send(UpdateSequences::Shutdown);
}
self.handles.remove(worker);
self.workers.remove(worker);
// Collect request_ids to remove from request_to_lora
let requests_to_remove: Vec<RequestId> = self
.request_to_worker
.iter()
......@@ -758,26 +565,18 @@ impl ActiveSequencesMultiWorker {
.map(|entry| entry.key().clone())
.collect();
// Clean up request_to_worker mappings for this worker
self.request_to_worker
.retain(|_request_id, mapped_worker| mapped_worker != worker);
// Clean up request_to_lora mappings for removed requests
for request_id in requests_to_remove {
self.request_to_lora.remove(&request_id);
}
}
// Add new workers
for worker in &workers_to_add {
tracing::warn!("Adding worker {:?}", worker);
let (sender, handle) = Self::start_worker(
self.block_size,
self.component.drt().runtime().child_token(),
);
self.senders.insert(*worker, sender);
self.handles.insert(*worker, handle);
self.workers
.insert(*worker, ActiveSequences::new(self.block_size));
}
}
......@@ -792,15 +591,9 @@ impl ActiveSequencesMultiWorker {
lora_name,
} = req;
// Clone the sender upfront so we don't hold the DashMap Ref across
// the .await points below. Also eliminates the TOCTOU between
// contains_key and a later get().unwrap().
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
if !self.workers.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
return Err(SequenceError::DuplicateRequest {
......@@ -809,8 +602,6 @@ impl ActiveSequencesMultiWorker {
});
}
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
......@@ -833,38 +624,37 @@ impl ActiveSequencesMultiWorker {
self.request_to_lora.insert(request_id.clone(), lora);
}
sender
.send(UpdateSequences::AddRequest {
let removed_requests = {
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_request(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
resp_tx,
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
let removed_requests = resp_rx
.await
.map_err(|_| SequenceError::WorkerChannelClosed)?;
)
};
for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id);
}
self.publish_active_load_for_worker(worker).await;
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Send a command to the worker assigned to a request, optionally publishing
/// Send a mutation to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
async fn send_to_request_worker(
async fn mutate_request_worker(
&self,
request_id: &RequestId,
event_data: ActiveSequenceEventData,
command_fn: impl FnOnce(RequestId) -> UpdateSequences,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
......@@ -875,13 +665,6 @@ impl ActiveSequencesMultiWorker {
request_id: request_id.clone(),
})?;
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
if self.replica_sync {
let lora_name = self
.request_to_lora
......@@ -898,16 +681,20 @@ impl ActiveSequencesMultiWorker {
self.event_publisher.publish(&event).await?;
}
sender
.send(command_fn(request_id.clone()))
.map_err(|_| SequenceError::WorkerChannelClosed)?;
{
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
mutate_fn(&mut entry, request_id);
}
if remove_mapping {
self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id);
}
self.publish_active_load_for_worker(worker).await;
self.publish_active_load_for_worker(worker);
Ok(())
}
......@@ -922,10 +709,12 @@ impl ActiveSequencesMultiWorker {
return Ok(());
}
self.send_to_request_worker(
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::Free,
|rid| UpdateSequences::Free { request_id: rid },
|seqs, rid| {
seqs.free(rid);
},
true,
)
.await
......@@ -939,10 +728,12 @@ impl ActiveSequencesMultiWorker {
&self,
request_id: &RequestId,
) -> Result<(), SequenceError> {
self.send_to_request_worker(
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::MarkPrefillCompleted,
|rid| UpdateSequences::MarkPrefillCompleted { request_id: rid },
|seqs, rid| {
seqs.mark_prefill_completed(rid);
},
false,
)
.await
......@@ -952,7 +743,9 @@ impl ActiveSequencesMultiWorker {
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction represents how "temporary" the block is based on generation progress.
pub async fn add_output_block(
// TODO: output blocks are not replicated via replica_sync — add an
// ActiveSequenceEventData variant if cross-instance accuracy matters.
pub fn add_output_block(
&self,
request_id: &RequestId,
decay_fraction: Option<f64>,
......@@ -965,30 +758,13 @@ impl ActiveSequencesMultiWorker {
request_id: request_id.clone(),
})?;
// Clone sender upfront to avoid TOCTOU between contains_key and get().unwrap()
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Send command to worker
sender
.send(UpdateSequences::AddOutputBlock {
request_id: request_id.clone(),
decay_fraction,
resp_tx,
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
// Wait for response
let success = resp_rx
.await
.map_err(|_| SequenceError::WorkerChannelClosed)?;
let success = {
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_output_block(request_id, decay_fraction)
};
if !success {
return Err(SequenceError::RequestNotFound {
......@@ -996,56 +772,22 @@ impl ActiveSequencesMultiWorker {
});
}
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
// Clone the sender and drop the DashMap Ref immediately.
// Holding a Ref across .await points can deadlock: if the task yields
// and update_workers() needs a write lock on the same shard, the
// runtime thread blocks forever.
let sender = {
let Some(entry) = self.senders.get(&worker) else {
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
/// The NATS publish is spawned as a background task to avoid blocking the caller.
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let (active_blocks, active_tokens) = {
let Some(entry) = self.workers.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return;
};
entry.value().clone()
};
// Query active blocks
let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveBlocks { resp_tx: blocks_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveBlocks query to worker {worker:?}");
return;
}
// Query active tokens
let (tokens_tx, tokens_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveTokens { resp_tx: tokens_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveTokens query to worker {worker:?}");
return;
}
// Await both responses
let (active_blocks, active_tokens) = match tokio::join!(blocks_rx, tokens_rx) {
(Ok(blocks), Ok(tokens)) => (blocks, tokens),
_ => {
tracing::warn!("Failed to receive active blocks/tokens from worker {worker:?}");
return;
}
(entry.active_blocks(), entry.active_tokens())
};
// Update Prometheus gauges directly (router's own bookkeeping)
WORKER_LOAD_METRICS.observe(
worker.worker_id,
worker.dp_rank,
......@@ -1054,7 +796,6 @@ impl ActiveSequencesMultiWorker {
active_tokens,
);
// Also publish ActiveLoad to NATS for other subscribers (if NATS is available)
let active_load = ActiveLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
......@@ -1062,15 +803,19 @@ impl ActiveSequencesMultiWorker {
active_prefill_tokens: Some(active_tokens as u64),
};
if let Err(e) = self.metrics_publisher.publish(&active_load).await {
// This is expected if NATS is not available - the local gauge update above already succeeded
tracing::trace!("Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}");
let publisher = self.metrics_publisher.clone();
tokio::spawn(async move {
if let Err(e) = publisher.publish(&active_load).await {
tracing::trace!(
"Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"
);
}
});
}
/// Get the number of workers
pub fn num_workers(&self) -> usize {
self.senders.len()
self.workers.len()
}
/// Get the worker type for this router ("prefill" or "decode").
......@@ -1079,80 +824,35 @@ impl ActiveSequencesMultiWorker {
self.worker_type
}
/// Generic method to query all workers with a given command
async fn query_workers<T: Send + 'static>(
&self,
token_sequence: Option<Vec<SequenceHash>>,
command_fn: impl Fn(
Option<Arc<Vec<SequenceHash>>>,
tokio::sync::oneshot::Sender<T>,
) -> UpdateSequences,
) -> HashMap<WorkerWithDpRank, T> {
let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Send queries to all workers in parallel
for entry in self.senders.iter() {
let worker = *entry.key();
let sender = entry.value();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker, resp_rx));
if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) {
tracing::error!("Failed to send command to worker {:?}: {}", worker, e);
}
}
// Collect results from all workers
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok(result)) => {
results.insert(worker, result);
}
Ok(Err(_)) => {
tracing::error!("Worker {:?} dropped response channel", worker);
}
Err(_) => {
tracing::error!("Timeout waiting for response from worker {:?}", worker);
}
}
}
results
}
/// Query all workers for the number of new blocks that would be added by a token sequence
pub async fn new_blocks(
pub fn new_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::NewBlocks {
token_sequence: ts,
resp_tx,
},
None => unreachable!("token_sequence should always be Some for new_blocks"),
})
.await
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().new_blocks(&token_sequence));
}
results
}
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub async fn potential_blocks(
pub fn potential_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts {
Some(ts) => UpdateSequences::PotentialBlocks {
token_sequence: ts,
resp_tx,
},
None => unreachable!("token_sequence should always be Some for potential_blocks"),
})
.await
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(
*entry.key(),
entry.value().potential_blocks(&token_sequence),
);
}
results
}
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
pub async fn potential_blocks_and_tokens(
/// Query all workers for the potential blocks and tokens
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
......@@ -1164,64 +864,28 @@ impl ActiveSequencesMultiWorker {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let num_workers = self.senders.len();
let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
let num_workers = self.workers.len();
// Iterate through all workers, not just those with overlap
// This ensures we properly account for active tokens/blocks on all workers
for sender_entry in self.senders.iter() {
let worker = *sender_entry.key();
let sender = sender_entry.value();
let mut potential_blocks = HashMap::with_capacity(self.workers.len());
let mut potential_tokens = HashMap::with_capacity(self.workers.len());
// Get overlap for this worker (defaults to 0 if not in overlaps)
for entry in self.workers.iter() {
let worker = *entry.key();
let overlap = *overlaps.scores.get(&worker).unwrap_or(&0);
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker, resp_rx));
if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens {
token_sequence: token_sequence_shared.clone(),
isl,
overlap,
resp_tx,
}) {
tracing::error!(
"Failed to send potential_tokens command to worker {:?}: {}",
worker,
e
);
}
}
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
// Collect results from all workers
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok((blocks, tokens))) => {
let (blocks, tokens) =
entry
.value()
.potential_blocks_and_tokens(token_sequence.as_deref(), isl, overlap);
potential_blocks.insert(worker, blocks);
potential_tokens.insert(worker, tokens);
}
Ok(Err(_)) => {
tracing::error!("Worker {:?} dropped response channel", worker);
}
Err(_) => {
tracing::error!("Timeout waiting for response from worker {:?}", worker);
}
}
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
......@@ -1231,15 +895,21 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for their current number of active blocks
pub async fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx })
.await
pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_blocks());
}
results
}
/// Query all workers for their current number of active tokens
pub async fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx })
.await
pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_tokens());
}
results
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
......@@ -1252,15 +922,6 @@ impl ActiveSequencesMultiWorker {
}
}
impl Drop for ActiveSequencesMultiWorker {
fn drop(&mut self) {
// Send shutdown to all workers
for entry in self.senders.iter() {
let _ = entry.value().send(UpdateSequences::Shutdown);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
......@@ -1400,8 +1061,8 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let blocks_phase1 = seq_manager_1.active_blocks().await;
let tokens_phase1 = seq_manager_1.active_tokens().await;
let blocks_phase1 = seq_manager_1.active_blocks();
let tokens_phase1 = seq_manager_1.active_tokens();
// Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// We now have:
......@@ -1450,8 +1111,8 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty
let blocks_phase2 = seq_manager_2.active_blocks().await;
let tokens_phase2 = seq_manager_2.active_tokens().await;
let blocks_phase2 = seq_manager_2.active_blocks();
let tokens_phase2 = seq_manager_2.active_tokens();
// Verify phase 2 results - everything should be empty for all 3 workers
let all_workers = vec![
......@@ -1579,7 +1240,7 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let tokens_phase1 = seq_manager_1.active_tokens().await;
let tokens_phase1 = seq_manager_1.active_tokens();
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
let worker_0 = WorkerWithDpRank::from_worker_id(0);
......@@ -1621,7 +1282,7 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty
let tokens_phase2 = seq_manager_2.active_tokens().await;
let tokens_phase2 = seq_manager_2.active_tokens();
// Verify phase 2 results - everything should be empty
for worker_id in 0..=2 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment