Unverified Commit efa89448 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: de-async scheduler read paths and unblock decode output tracking (#6510)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 6fab12be
...@@ -494,14 +494,12 @@ impl KvRouter { ...@@ -494,14 +494,12 @@ impl KvRouter {
self.scheduler.worker_type() self.scheduler.worker_type()
} }
pub async fn add_output_block( pub fn add_output_block(
&self, &self,
request_id: &str, request_id: &str,
decay_fraction: Option<f64>, decay_fraction: Option<f64>,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.scheduler self.scheduler.add_output_block(request_id, decay_fraction)
.add_output_block(request_id, decay_fraction)
.await
} }
pub fn block_size(&self) -> u32 { pub fn block_size(&self) -> u32 {
...@@ -541,8 +539,7 @@ impl KvRouter { ...@@ -541,8 +539,7 @@ impl KvRouter {
Ok(self Ok(self
.scheduler .scheduler
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores) .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores))
.await)
} }
/// Dump all events from the indexer /// Dump all events from the indexer
......
...@@ -108,7 +108,6 @@ impl RequestGuard { ...@@ -108,7 +108,6 @@ impl RequestGuard {
if let Err(e) = self if let Err(e) = self
.chooser .chooser
.add_output_block(&self.context_id, decay_fraction) .add_output_block(&self.context_id, decay_fraction)
.await
{ {
tracing::warn!( tracing::warn!(
"Failed to add output block for request {}: {e}", "Failed to add output block for request {}: {e}",
......
...@@ -103,7 +103,7 @@ impl SchedulerQueue { ...@@ -103,7 +103,7 @@ impl SchedulerQueue {
return; return;
}; };
if self.all_workers_busy(threshold).await { if self.all_workers_busy(threshold) {
tracing::debug!("all workers busy, queueing request"); tracing::debug!("all workers busy, queueing request");
let entry = self.make_entry(request); let entry = self.make_entry(request);
self.pending.lock().await.push(entry); self.pending.lock().await.push(entry);
...@@ -121,7 +121,7 @@ impl SchedulerQueue { ...@@ -121,7 +121,7 @@ impl SchedulerQueue {
}; };
loop { loop {
if self.all_workers_busy(threshold).await { if self.all_workers_busy(threshold) {
break; break;
} }
let Some(entry) = self.pending.lock().await.pop() else { let Some(entry) = self.pending.lock().await.pop() else {
...@@ -135,14 +135,11 @@ impl SchedulerQueue { ...@@ -135,14 +135,11 @@ impl SchedulerQueue {
/// Run the full scheduling pipeline for a single request: /// Run the full scheduling pipeline for a single request:
/// compute potential load → select worker → respond → book via add_request. /// compute potential load → select worker → respond → book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) { async fn schedule(&self, mut request: SchedulingRequest) {
let (decode_blocks, prefill_tokens) = self let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens(
.slots
.potential_blocks_and_tokens(
request.token_seq.clone(), request.token_seq.clone(),
request.isl_tokens, request.isl_tokens,
request.overlaps.clone(), request.overlaps.clone(),
) );
.await;
request.decode_blocks = decode_blocks; request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
...@@ -194,8 +191,8 @@ impl SchedulerQueue { ...@@ -194,8 +191,8 @@ impl SchedulerQueue {
/// Check if all workers are busy based on threshold. /// Check if all workers are busy based on threshold.
/// Returns true only if ALL workers exceed the threshold (no worker has capacity). /// Returns true only if ALL workers exceed the threshold (no worker has capacity).
async fn all_workers_busy(&self, threshold: f64) -> bool { fn all_workers_busy(&self, threshold: f64) -> bool {
let active_tokens = self.slots.active_tokens().await; let active_tokens = self.slots.active_tokens();
let configs = self.workers_with_configs.borrow(); let configs = self.workers_with_configs.borrow();
for (&worker_id, config) in configs.iter() { for (&worker_id, config) in configs.iter() {
......
...@@ -272,17 +272,16 @@ impl KvScheduler { ...@@ -272,17 +272,16 @@ impl KvScheduler {
self.slots.worker_type() self.slots.worker_type()
} }
pub async fn add_output_block( pub fn add_output_block(
&self, &self,
request_id: &str, request_id: &str,
decay_fraction: Option<f64>, decay_fraction: Option<f64>,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.slots self.slots
.add_output_block(&request_id.to_string(), decay_fraction) .add_output_block(&request_id.to_string(), decay_fraction)
.await
} }
pub async fn get_potential_loads( pub fn get_potential_loads(
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
...@@ -290,8 +289,7 @@ impl KvScheduler { ...@@ -290,8 +289,7 @@ impl KvScheduler {
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) = self let (decode_blocks, prefill_tokens) = self
.slots .slots
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps) .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps);
.await;
// Get all unique WorkerWithDpRank from both hashmaps // Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new(); let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
......
...@@ -9,12 +9,11 @@ ...@@ -9,12 +9,11 @@
//! //!
//! # Key Components //! # Key Components
//! //!
//! - [`ActiveSequences`]: Single-threaded sequence manager that tracks active requests and their //! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//! token sequences, managing shared KV cache blocks efficiently. //! token sequences, managing shared KV cache blocks efficiently.
//! //!
//! - [`ActiveSequencesMultiWorker`]: Multi-threaded extension that distributes sequence management //! - [`ActiveSequencesMultiWorker`]: Multi-worker extension that stores per-worker
//! across multiple worker threads, enabling parallel processing of requests while maintaining //! `ActiveSequences` in a shared `DashMap` for lock-free concurrent access.
//! consistency.
//! //!
//! # Architecture //! # Architecture
//! //!
...@@ -31,7 +30,6 @@ use dynamo_runtime::traits::DistributedRuntimeProvider; ...@@ -31,7 +30,6 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber}; use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
...@@ -62,9 +60,6 @@ pub enum SequenceError { ...@@ -62,9 +60,6 @@ pub enum SequenceError {
#[error("Failed to publish event: {0}")] #[error("Failed to publish event: {0}")]
PublishFailed(#[from] anyhow::Error), PublishFailed(#[from] anyhow::Error),
#[error("Failed to send command to worker: channel closed")]
WorkerChannelClosed,
} }
/// Duration after which stale requests are forcibly expired (5 minutes) /// Duration after which stale requests are forcibly expired (5 minutes)
...@@ -87,14 +82,14 @@ pub struct SequenceRequest { ...@@ -87,14 +82,14 @@ pub struct SequenceRequest {
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug, Getters)]
pub struct ActiveSequences { pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Rc<()>)>>, active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>, prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation) /// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>, expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, Weak<()>>, unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
/// Fractional block counts for blocks that are partially cached /// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks, /// When a block is in both unique_blocks and fractional_blocks,
...@@ -133,15 +128,15 @@ impl ActiveSequences { ...@@ -133,15 +128,15 @@ impl ActiveSequences {
} }
} }
fn touch_block(&mut self, block: &SequenceHash) -> Rc<()> { fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block) if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade() && let Some(rc) = weak.upgrade()
{ {
return rc; return rc;
} }
let rc = Rc::new(()); let rc = Arc::new(());
self.unique_blocks.insert(*block, Rc::downgrade(&rc)); self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc rc
} }
...@@ -177,7 +172,7 @@ impl ActiveSequences { ...@@ -177,7 +172,7 @@ impl ActiveSequences {
for (hash, rc) in blocks { for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference // A block with strong_count == 1 means only this request holds a reference
if Rc::strong_count(rc) == 1 { if Arc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction); self.fractional_blocks.insert(*hash, fraction);
} }
} }
...@@ -214,7 +209,7 @@ impl ActiveSequences { ...@@ -214,7 +209,7 @@ impl ActiveSequences {
} }
if let Some(sequence) = token_sequence { if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
.iter() .iter()
.map(|block| (*block, self.touch_block(block))) .map(|block| (*block, self.touch_block(block)))
.collect(); .collect();
...@@ -370,64 +365,16 @@ impl ActiveSequences { ...@@ -370,64 +365,16 @@ impl ActiveSequences {
} }
} }
enum UpdateSequences { /// Multi-worker extension of ActiveSequences using shared DashMap for lock-free concurrent access
AddRequest {
request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
resp_tx: tokio::sync::oneshot::Sender<HashSet<RequestId>>,
},
Free {
request_id: RequestId,
},
MarkPrefillCompleted {
request_id: RequestId,
},
AddOutputBlock {
request_id: RequestId,
decay_fraction: Option<f64>,
resp_tx: tokio::sync::oneshot::Sender<bool>,
},
NewBlocks {
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
PotentialBlocks {
token_sequence: Arc<Vec<SequenceHash>>,
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
PotentialBlocksAndTokens {
token_sequence: Option<Arc<Vec<SequenceHash>>>,
isl: usize,
overlap: u32,
resp_tx: tokio::sync::oneshot::Sender<(usize, usize)>,
},
ActiveBlocks {
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
ActiveTokens {
resp_tx: tokio::sync::oneshot::Sender<usize>,
},
Shutdown,
}
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
pub struct ActiveSequencesMultiWorker { pub struct ActiveSequencesMultiWorker {
senders: Arc<DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>>, workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>, request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>, request_to_lora: Arc<DashMap<RequestId, String>>,
handles: Arc<DashMap<WorkerWithDpRank, std::thread::JoinHandle<()>>>,
block_size: usize, block_size: usize,
component: Component,
router_id: u64, router_id: u64,
/// Publisher for sequence events
event_publisher: EventPublisher, event_publisher: EventPublisher,
/// Publisher for metrics (namespace-scoped) metrics_publisher: Arc<EventPublisher>,
metrics_publisher: EventPublisher,
replica_sync: bool, replica_sync: bool,
/// Worker type for Prometheus metrics labeling ("prefill" or "decode")
worker_type: &'static str, worker_type: &'static str,
} }
...@@ -442,37 +389,30 @@ impl ActiveSequencesMultiWorker { ...@@ -442,37 +389,30 @@ impl ActiveSequencesMultiWorker {
) -> Result<Self> { ) -> Result<Self> {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
let senders = Arc::new(DashMap::new()); let workers = Arc::new(DashMap::new());
let handles = Arc::new(DashMap::new());
let request_to_worker = Arc::new(DashMap::new()); let request_to_worker = Arc::new(DashMap::new());
let request_to_lora = Arc::new(DashMap::new()); let request_to_lora = Arc::new(DashMap::new());
// Expand workers by their dp_rank
for (worker_id, config) in workers_with_configs { for (worker_id, config) in workers_with_configs {
let dp_size = config.data_parallel_size; let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size { for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Create a child cancellation token from the component's runtime workers.insert(worker, ActiveSequences::new(block_size));
let cancel_token = component.drt().runtime().child_token();
let (sender, handle) = Self::start_worker(block_size, cancel_token);
senders.insert(worker, sender);
handles.insert(worker, handle);
} }
} }
let event_publisher = let event_publisher =
EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?; EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
let metrics_publisher = let metrics_publisher = Arc::new(
EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?; EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?,
);
let multi_worker = Self { let multi_worker = Self {
senders: senders.clone(), workers: workers.clone(),
request_to_worker: request_to_worker.clone(), request_to_worker: request_to_worker.clone(),
request_to_lora: request_to_lora.clone(), request_to_lora: request_to_lora.clone(),
handles,
block_size, block_size,
component: component.clone(),
event_publisher, event_publisher,
metrics_publisher, metrics_publisher,
router_id, router_id,
...@@ -480,9 +420,8 @@ impl ActiveSequencesMultiWorker { ...@@ -480,9 +420,8 @@ impl ActiveSequencesMultiWorker {
worker_type, worker_type,
}; };
// Start the subscription loop only if replica_sync is enabled
if replica_sync { if replica_sync {
let senders_clone = senders.clone(); let workers_clone = workers.clone();
let request_to_worker_clone = request_to_worker.clone(); let request_to_worker_clone = request_to_worker.clone();
let request_to_lora_clone = request_to_lora.clone(); let request_to_lora_clone = request_to_lora.clone();
let component_clone = component.clone(); let component_clone = component.clone();
...@@ -490,9 +429,8 @@ impl ActiveSequencesMultiWorker { ...@@ -490,9 +429,8 @@ impl ActiveSequencesMultiWorker {
let cancel_token = component.drt().runtime().child_token(); let cancel_token = component.drt().runtime().child_token();
tokio::spawn(async move { tokio::spawn(async move {
// NATS subscription loop
if let Err(e) = Self::subscribe_to_events( if let Err(e) = Self::subscribe_to_events(
senders_clone, workers_clone,
request_to_worker_clone, request_to_worker_clone,
request_to_lora_clone, request_to_lora_clone,
component_clone, component_clone,
...@@ -509,120 +447,9 @@ impl ActiveSequencesMultiWorker { ...@@ -509,120 +447,9 @@ impl ActiveSequencesMultiWorker {
Ok(multi_worker) Ok(multi_worker)
} }
/// Helper method to start a worker task
fn start_worker(
block_size: usize,
cancel_token: CancellationToken,
) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
std::thread::JoinHandle<()>,
) {
let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let handle = std::thread::spawn(move || {
// Create a single-threaded tokio runtime
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(async move {
let mut active_sequences = ActiveSequences::new(block_size);
let mut request_rx = request_rx;
loop {
tokio::select! {
command = request_rx.recv() => {
let Some(command) = command else {
break;
};
match command {
UpdateSequences::AddRequest {
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
resp_tx,
} => {
let removed = active_sequences.add_request(request_id, token_sequence, isl, overlap, expected_output_tokens);
let _ = resp_tx.send(removed);
}
UpdateSequences::Free { request_id } => {
active_sequences.free(&request_id);
}
UpdateSequences::MarkPrefillCompleted { request_id } => {
active_sequences.mark_prefill_completed(&request_id);
}
UpdateSequences::AddOutputBlock {
request_id,
decay_fraction,
resp_tx,
} => {
let success = active_sequences.add_output_block(&request_id, decay_fraction);
let _ = resp_tx.send(success);
}
UpdateSequences::NewBlocks {
token_sequence,
resp_tx,
} => {
let new_blocks = active_sequences.new_blocks(&token_sequence);
let _ = resp_tx.send(new_blocks);
}
UpdateSequences::PotentialBlocks {
token_sequence,
resp_tx,
} => {
let potential_blocks = active_sequences.potential_blocks(&token_sequence);
let _ = resp_tx.send(potential_blocks);
}
UpdateSequences::PotentialBlocksAndTokens {
token_sequence,
isl,
overlap,
resp_tx,
} => {
let potential_tokens = active_sequences.potential_blocks_and_tokens(
token_sequence.as_ref().map(|v| v.as_slice()),
isl,
overlap,
);
let _ = resp_tx.send(potential_tokens);
}
UpdateSequences::ActiveBlocks { resp_tx } => {
let active_blocks = active_sequences.active_blocks();
let _ = resp_tx.send(active_blocks);
}
UpdateSequences::ActiveTokens { resp_tx } => {
let active_tokens = active_sequences.active_tokens();
let _ = resp_tx.send(active_tokens);
}
UpdateSequences::Shutdown => {
break;
}
}
}
// Handle cancellation
_ = cancel_token.cancelled() => {
tracing::debug!("Worker task cancelled");
break;
}
}
}
});
tracing::debug!("ActiveSequences worker task completed");
});
(request_tx, handle)
}
/// Background task to subscribe to active sequence events and update all workers /// Background task to subscribe to active sequence events and update all workers
async fn subscribe_to_events( async fn subscribe_to_events(
senders: Arc< workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
DashMap<WorkerWithDpRank, tokio::sync::mpsc::UnboundedSender<UpdateSequences>>,
>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>, request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>, request_to_lora: Arc<DashMap<RequestId, String>>,
component: Component, component: Component,
...@@ -635,10 +462,8 @@ impl ActiveSequencesMultiWorker { ...@@ -635,10 +462,8 @@ impl ActiveSequencesMultiWorker {
loop { loop {
tokio::select! { tokio::select! {
// Handle incoming events
result = subscriber.next() => { result = subscriber.next() => {
let Some(result) = result else { let Some(result) = result else {
// Stream ended
break; break;
}; };
...@@ -650,7 +475,6 @@ impl ActiveSequencesMultiWorker { ...@@ -650,7 +475,6 @@ impl ActiveSequencesMultiWorker {
continue; continue;
}; };
// Skip events emitted by itself
if event.router_id == router_id { if event.router_id == router_id {
continue; continue;
} }
...@@ -664,22 +488,18 @@ impl ActiveSequencesMultiWorker { ...@@ -664,22 +488,18 @@ impl ActiveSequencesMultiWorker {
} => { } => {
request_to_worker.insert(event.request_id.clone(), event.worker); request_to_worker.insert(event.request_id.clone(), event.worker);
// Store lora_name mapping if present
if let Some(ref lora_name) = event.lora_name { if let Some(ref lora_name) = event.lora_name {
request_to_lora.insert(event.request_id.clone(), lora_name.clone()); request_to_lora.insert(event.request_id.clone(), lora_name.clone());
} }
if let Some(sender) = senders.get(&event.worker) { if let Some(mut entry) = workers.get_mut(&event.worker) {
// For replicated events, we create a dummy response channel since we don't need to handle expired requests entry.add_request(
let (resp_tx, _) = tokio::sync::oneshot::channel(); event.request_id.clone(),
let _ = sender.send(UpdateSequences::AddRequest { token_sequence.clone(),
request_id: event.request_id.clone(), *isl,
token_sequence: token_sequence.clone(), *overlap,
isl: *isl, *expected_output_tokens,
overlap: *overlap, );
expected_output_tokens: *expected_output_tokens,
resp_tx,
});
} else { } else {
tracing::warn!( tracing::warn!(
"Worker {:?} not found, cannot process AddRequest", "Worker {:?} not found, cannot process AddRequest",
...@@ -689,27 +509,21 @@ impl ActiveSequencesMultiWorker { ...@@ -689,27 +509,21 @@ impl ActiveSequencesMultiWorker {
} }
ActiveSequenceEventData::Free => { ActiveSequenceEventData::Free => {
if let Some((_, worker)) = request_to_worker.remove(&event.request_id) if let Some((_, worker)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker) && let Some(mut entry) = workers.get_mut(&worker)
{ {
let _ = sender.send(UpdateSequences::Free { entry.free(&event.request_id);
request_id: event.request_id.clone(),
});
} }
// Clean up lora_name mapping
request_to_lora.remove(&event.request_id); request_to_lora.remove(&event.request_id);
} }
ActiveSequenceEventData::MarkPrefillCompleted => { ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker) = request_to_worker.get(&event.request_id) if let Some(worker) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker) && let Some(mut entry) = workers.get_mut(&*worker)
{ {
let _ = sender.send(UpdateSequences::MarkPrefillCompleted { entry.mark_prefill_completed(&event.request_id);
request_id: event.request_id.clone(),
});
} }
} }
} }
} }
// Handle cancellation
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled"); tracing::debug!("Subscription task cancelled");
break; break;
...@@ -723,9 +537,8 @@ impl ActiveSequencesMultiWorker { ...@@ -723,9 +537,8 @@ impl ActiveSequencesMultiWorker {
/// Update the set of workers, adding and removing as needed /// Update the set of workers, adding and removing as needed
pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) { pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) {
let current_workers: HashSet<WorkerWithDpRank> = let current_workers: HashSet<WorkerWithDpRank> =
self.senders.iter().map(|entry| *entry.key()).collect(); self.workers.iter().map(|entry| *entry.key()).collect();
// Expand new workers by their dp_rank
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new(); let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs { for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.data_parallel_size; let dp_size = config.data_parallel_size;
...@@ -740,17 +553,11 @@ impl ActiveSequencesMultiWorker { ...@@ -740,17 +553,11 @@ impl ActiveSequencesMultiWorker {
let workers_to_add: Vec<WorkerWithDpRank> = let workers_to_add: Vec<WorkerWithDpRank> =
new_workers.difference(&current_workers).copied().collect(); new_workers.difference(&current_workers).copied().collect();
// Remove workers (this will naturally remove all dp ranks for a worker_id)
for worker in &workers_to_remove { for worker in &workers_to_remove {
tracing::warn!("Removing worker {:?}", worker); tracing::warn!("Removing worker {:?}", worker);
// Send shutdown command to the worker self.workers.remove(worker);
if let Some((_, sender)) = self.senders.remove(worker) {
let _ = sender.send(UpdateSequences::Shutdown);
}
self.handles.remove(worker);
// Collect request_ids to remove from request_to_lora
let requests_to_remove: Vec<RequestId> = self let requests_to_remove: Vec<RequestId> = self
.request_to_worker .request_to_worker
.iter() .iter()
...@@ -758,26 +565,18 @@ impl ActiveSequencesMultiWorker { ...@@ -758,26 +565,18 @@ impl ActiveSequencesMultiWorker {
.map(|entry| entry.key().clone()) .map(|entry| entry.key().clone())
.collect(); .collect();
// Clean up request_to_worker mappings for this worker
self.request_to_worker self.request_to_worker
.retain(|_request_id, mapped_worker| mapped_worker != worker); .retain(|_request_id, mapped_worker| mapped_worker != worker);
// Clean up request_to_lora mappings for removed requests
for request_id in requests_to_remove { for request_id in requests_to_remove {
self.request_to_lora.remove(&request_id); self.request_to_lora.remove(&request_id);
} }
} }
// Add new workers
for worker in &workers_to_add { for worker in &workers_to_add {
tracing::warn!("Adding worker {:?}", worker); tracing::warn!("Adding worker {:?}", worker);
self.workers
let (sender, handle) = Self::start_worker( .insert(*worker, ActiveSequences::new(self.block_size));
self.block_size,
self.component.drt().runtime().child_token(),
);
self.senders.insert(*worker, sender);
self.handles.insert(*worker, handle);
} }
} }
...@@ -792,15 +591,9 @@ impl ActiveSequencesMultiWorker { ...@@ -792,15 +591,9 @@ impl ActiveSequencesMultiWorker {
lora_name, lora_name,
} = req; } = req;
// Clone the sender upfront so we don't hold the DashMap Ref across if !self.workers.contains_key(&worker) {
// the .await points below. Also eliminates the TOCTOU between return Err(SequenceError::WorkerNotFound { worker });
// contains_key and a later get().unwrap(). }
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
if let Some(existing_worker) = self.request_to_worker.get(&request_id) { if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
return Err(SequenceError::DuplicateRequest { return Err(SequenceError::DuplicateRequest {
...@@ -809,8 +602,6 @@ impl ActiveSequencesMultiWorker { ...@@ -809,8 +602,6 @@ impl ActiveSequencesMultiWorker {
}); });
} }
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
if self.replica_sync { if self.replica_sync {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
...@@ -833,38 +624,37 @@ impl ActiveSequencesMultiWorker { ...@@ -833,38 +624,37 @@ impl ActiveSequencesMultiWorker {
self.request_to_lora.insert(request_id.clone(), lora); self.request_to_lora.insert(request_id.clone(), lora);
} }
sender let removed_requests = {
.send(UpdateSequences::AddRequest { let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_request(
request_id, request_id,
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
expected_output_tokens, expected_output_tokens,
resp_tx, )
}) };
.map_err(|_| SequenceError::WorkerChannelClosed)?;
let removed_requests = resp_rx
.await
.map_err(|_| SequenceError::WorkerChannelClosed)?;
for expired_id in &removed_requests { for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id); self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id); self.request_to_lora.remove(expired_id);
} }
self.publish_active_load_for_worker(worker).await; self.publish_active_load_for_worker(worker);
Ok(()) Ok(())
} }
/// Send a command to the worker assigned to a request, optionally publishing /// Send a mutation to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward. /// a replica-sync event and cleaning up request mappings afterward.
async fn send_to_request_worker( async fn mutate_request_worker(
&self, &self,
request_id: &RequestId, request_id: &RequestId,
event_data: ActiveSequenceEventData, event_data: ActiveSequenceEventData,
command_fn: impl FnOnce(RequestId) -> UpdateSequences, mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
remove_mapping: bool, remove_mapping: bool,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
let worker = self let worker = self
...@@ -875,13 +665,6 @@ impl ActiveSequencesMultiWorker { ...@@ -875,13 +665,6 @@ impl ActiveSequencesMultiWorker {
request_id: request_id.clone(), request_id: request_id.clone(),
})?; })?;
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
if self.replica_sync { if self.replica_sync {
let lora_name = self let lora_name = self
.request_to_lora .request_to_lora
...@@ -898,16 +681,20 @@ impl ActiveSequencesMultiWorker { ...@@ -898,16 +681,20 @@ impl ActiveSequencesMultiWorker {
self.event_publisher.publish(&event).await?; self.event_publisher.publish(&event).await?;
} }
sender {
.send(command_fn(request_id.clone())) let mut entry = self
.map_err(|_| SequenceError::WorkerChannelClosed)?; .workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
mutate_fn(&mut entry, request_id);
}
if remove_mapping { if remove_mapping {
self.request_to_worker.remove(request_id); self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id); self.request_to_lora.remove(request_id);
} }
self.publish_active_load_for_worker(worker).await; self.publish_active_load_for_worker(worker);
Ok(()) Ok(())
} }
...@@ -922,10 +709,12 @@ impl ActiveSequencesMultiWorker { ...@@ -922,10 +709,12 @@ impl ActiveSequencesMultiWorker {
return Ok(()); return Ok(());
} }
self.send_to_request_worker( self.mutate_request_worker(
request_id, request_id,
ActiveSequenceEventData::Free, ActiveSequenceEventData::Free,
|rid| UpdateSequences::Free { request_id: rid }, |seqs, rid| {
seqs.free(rid);
},
true, true,
) )
.await .await
...@@ -939,10 +728,12 @@ impl ActiveSequencesMultiWorker { ...@@ -939,10 +728,12 @@ impl ActiveSequencesMultiWorker {
&self, &self,
request_id: &RequestId, request_id: &RequestId,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.send_to_request_worker( self.mutate_request_worker(
request_id, request_id,
ActiveSequenceEventData::MarkPrefillCompleted, ActiveSequenceEventData::MarkPrefillCompleted,
|rid| UpdateSequences::MarkPrefillCompleted { request_id: rid }, |seqs, rid| {
seqs.mark_prefill_completed(rid);
},
false, false,
) )
.await .await
...@@ -952,7 +743,9 @@ impl ActiveSequencesMultiWorker { ...@@ -952,7 +743,9 @@ impl ActiveSequencesMultiWorker {
/// ///
/// This is used during generation to track output blocks as they are created. /// This is used during generation to track output blocks as they are created.
/// The decay_fraction represents how "temporary" the block is based on generation progress. /// The decay_fraction represents how "temporary" the block is based on generation progress.
pub async fn add_output_block( // TODO: output blocks are not replicated via replica_sync — add an
// ActiveSequenceEventData variant if cross-instance accuracy matters.
pub fn add_output_block(
&self, &self,
request_id: &RequestId, request_id: &RequestId,
decay_fraction: Option<f64>, decay_fraction: Option<f64>,
...@@ -965,30 +758,13 @@ impl ActiveSequencesMultiWorker { ...@@ -965,30 +758,13 @@ impl ActiveSequencesMultiWorker {
request_id: request_id.clone(), request_id: request_id.clone(),
})?; })?;
// Clone sender upfront to avoid TOCTOU between contains_key and get().unwrap() let success = {
let sender = self let mut entry = self
.senders .workers
.get(&worker) .get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })? .ok_or(SequenceError::WorkerNotFound { worker })?;
.value() entry.add_output_block(request_id, decay_fraction)
.clone(); };
// Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Send command to worker
sender
.send(UpdateSequences::AddOutputBlock {
request_id: request_id.clone(),
decay_fraction,
resp_tx,
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
// Wait for response
let success = resp_rx
.await
.map_err(|_| SequenceError::WorkerChannelClosed)?;
if !success { if !success {
return Err(SequenceError::RequestNotFound { return Err(SequenceError::RequestNotFound {
...@@ -996,56 +772,22 @@ impl ActiveSequencesMultiWorker { ...@@ -996,56 +772,22 @@ impl ActiveSequencesMultiWorker {
}); });
} }
// Publish ActiveLoad metrics for this worker self.publish_active_load_for_worker(worker);
self.publish_active_load_for_worker(worker).await;
Ok(()) Ok(())
} }
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad /// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) { /// The NATS publish is spawned as a background task to avoid blocking the caller.
// Clone the sender and drop the DashMap Ref immediately. fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
// Holding a Ref across .await points can deadlock: if the task yields let (active_blocks, active_tokens) = {
// and update_workers() needs a write lock on the same shard, the let Some(entry) = self.workers.get(&worker) else {
// runtime thread blocks forever.
let sender = {
let Some(entry) = self.senders.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad"); tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return; return;
}; };
entry.value().clone() (entry.active_blocks(), entry.active_tokens())
};
// Query active blocks
let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveBlocks { resp_tx: blocks_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveBlocks query to worker {worker:?}");
return;
}
// Query active tokens
let (tokens_tx, tokens_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveTokens { resp_tx: tokens_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveTokens query to worker {worker:?}");
return;
}
// Await both responses
let (active_blocks, active_tokens) = match tokio::join!(blocks_rx, tokens_rx) {
(Ok(blocks), Ok(tokens)) => (blocks, tokens),
_ => {
tracing::warn!("Failed to receive active blocks/tokens from worker {worker:?}");
return;
}
}; };
// Update Prometheus gauges directly (router's own bookkeeping)
WORKER_LOAD_METRICS.observe( WORKER_LOAD_METRICS.observe(
worker.worker_id, worker.worker_id,
worker.dp_rank, worker.dp_rank,
...@@ -1054,7 +796,6 @@ impl ActiveSequencesMultiWorker { ...@@ -1054,7 +796,6 @@ impl ActiveSequencesMultiWorker {
active_tokens, active_tokens,
); );
// Also publish ActiveLoad to NATS for other subscribers (if NATS is available)
let active_load = ActiveLoad { let active_load = ActiveLoad {
worker_id: worker.worker_id, worker_id: worker.worker_id,
dp_rank: worker.dp_rank, dp_rank: worker.dp_rank,
...@@ -1062,15 +803,19 @@ impl ActiveSequencesMultiWorker { ...@@ -1062,15 +803,19 @@ impl ActiveSequencesMultiWorker {
active_prefill_tokens: Some(active_tokens as u64), active_prefill_tokens: Some(active_tokens as u64),
}; };
if let Err(e) = self.metrics_publisher.publish(&active_load).await { let publisher = self.metrics_publisher.clone();
// This is expected if NATS is not available - the local gauge update above already succeeded tokio::spawn(async move {
tracing::trace!("Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"); if let Err(e) = publisher.publish(&active_load).await {
tracing::trace!(
"Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"
);
} }
});
} }
/// Get the number of workers /// Get the number of workers
pub fn num_workers(&self) -> usize { pub fn num_workers(&self) -> usize {
self.senders.len() self.workers.len()
} }
/// Get the worker type for this router ("prefill" or "decode"). /// Get the worker type for this router ("prefill" or "decode").
...@@ -1079,80 +824,35 @@ impl ActiveSequencesMultiWorker { ...@@ -1079,80 +824,35 @@ impl ActiveSequencesMultiWorker {
self.worker_type self.worker_type
} }
/// Generic method to query all workers with a given command
async fn query_workers<T: Send + 'static>(
&self,
token_sequence: Option<Vec<SequenceHash>>,
command_fn: impl Fn(
Option<Arc<Vec<SequenceHash>>>,
tokio::sync::oneshot::Sender<T>,
) -> UpdateSequences,
) -> HashMap<WorkerWithDpRank, T> {
let mut results = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Send queries to all workers in parallel
for entry in self.senders.iter() {
let worker = *entry.key();
let sender = entry.value();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
receivers.push((worker, resp_rx));
if let Err(e) = sender.send(command_fn(token_sequence_shared.clone(), resp_tx)) {
tracing::error!("Failed to send command to worker {:?}: {}", worker, e);
}
}
// Collect results from all workers
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok(result)) => {
results.insert(worker, result);
}
Ok(Err(_)) => {
tracing::error!("Worker {:?} dropped response channel", worker);
}
Err(_) => {
tracing::error!("Timeout waiting for response from worker {:?}", worker);
}
}
}
results
}
/// Query all workers for the number of new blocks that would be added by a token sequence /// Query all workers for the number of new blocks that would be added by a token sequence
pub async fn new_blocks( pub fn new_blocks(
&self, &self,
token_sequence: Vec<SequenceHash>, token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> { ) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { let mut results = HashMap::with_capacity(self.workers.len());
Some(ts) => UpdateSequences::NewBlocks { for entry in self.workers.iter() {
token_sequence: ts, results.insert(*entry.key(), entry.value().new_blocks(&token_sequence));
resp_tx, }
}, results
None => unreachable!("token_sequence should always be Some for new_blocks"),
})
.await
} }
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub async fn potential_blocks( pub fn potential_blocks(
&self, &self,
token_sequence: Vec<SequenceHash>, token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> { ) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(Some(token_sequence), |ts, resp_tx| match ts { let mut results = HashMap::with_capacity(self.workers.len());
Some(ts) => UpdateSequences::PotentialBlocks { for entry in self.workers.iter() {
token_sequence: ts, results.insert(
resp_tx, *entry.key(),
}, entry.value().potential_blocks(&token_sequence),
None => unreachable!("token_sequence should always be Some for potential_blocks"), );
}) }
.await results
} }
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap /// Query all workers for the potential blocks and tokens
pub async fn potential_blocks_and_tokens( pub fn potential_blocks_and_tokens(
&self, &self,
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize, isl: usize,
...@@ -1164,64 +864,28 @@ impl ActiveSequencesMultiWorker { ...@@ -1164,64 +864,28 @@ impl ActiveSequencesMultiWorker {
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
let start = Instant::now(); let start = Instant::now();
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
let num_workers = self.senders.len(); let num_workers = self.workers.len();
let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
let mut receivers = Vec::new();
// Iterate through all workers, not just those with overlap let mut potential_blocks = HashMap::with_capacity(self.workers.len());
// This ensures we properly account for active tokens/blocks on all workers let mut potential_tokens = HashMap::with_capacity(self.workers.len());
for sender_entry in self.senders.iter() {
let worker = *sender_entry.key();
let sender = sender_entry.value();
// Get overlap for this worker (defaults to 0 if not in overlaps) for entry in self.workers.iter() {
let worker = *entry.key();
let overlap = *overlaps.scores.get(&worker).unwrap_or(&0); let overlap = *overlaps.scores.get(&worker).unwrap_or(&0);
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (blocks, tokens) =
receivers.push((worker, resp_rx)); entry
.value()
if let Err(e) = sender.send(UpdateSequences::PotentialBlocksAndTokens { .potential_blocks_and_tokens(token_sequence.as_deref(), isl, overlap);
token_sequence: token_sequence_shared.clone(),
isl,
overlap,
resp_tx,
}) {
tracing::error!(
"Failed to send potential_tokens command to worker {:?}: {}",
worker,
e
);
}
}
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
// Collect results from all workers
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
Ok(Ok((blocks, tokens))) => {
potential_blocks.insert(worker, blocks); potential_blocks.insert(worker, blocks);
potential_tokens.insert(worker, tokens); potential_tokens.insert(worker, tokens);
} }
Ok(Err(_)) => {
tracing::error!("Worker {:?} dropped response channel", worker);
}
Err(_) => {
tracing::error!("Timeout waiting for response from worker {:?}", worker);
}
}
}
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
{ {
let total_elapsed = start.elapsed(); let total_elapsed = start.elapsed();
tracing::info!( tracing::info!(
num_workers, num_workers,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64, total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed" "potential_blocks_and_tokens completed"
); );
...@@ -1231,15 +895,21 @@ impl ActiveSequencesMultiWorker { ...@@ -1231,15 +895,21 @@ impl ActiveSequencesMultiWorker {
} }
/// Query all workers for their current number of active blocks /// Query all workers for their current number of active blocks
pub async fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> { pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveBlocks { resp_tx }) let mut results = HashMap::with_capacity(self.workers.len());
.await for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_blocks());
}
results
} }
/// Query all workers for their current number of active tokens /// Query all workers for their current number of active tokens
pub async fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> { pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
self.query_workers(None, |_, resp_tx| UpdateSequences::ActiveTokens { resp_tx }) let mut results = HashMap::with_capacity(self.workers.len());
.await for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_tokens());
}
results
} }
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> { pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
...@@ -1252,15 +922,6 @@ impl ActiveSequencesMultiWorker { ...@@ -1252,15 +922,6 @@ impl ActiveSequencesMultiWorker {
} }
} }
impl Drop for ActiveSequencesMultiWorker {
fn drop(&mut self) {
// Send shutdown to all workers
for entry in self.senders.iter() {
let _ = entry.value().send(UpdateSequences::Shutdown);
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -1400,8 +1061,8 @@ mod tests { ...@@ -1400,8 +1061,8 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2 // Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let blocks_phase1 = seq_manager_1.active_blocks().await; let blocks_phase1 = seq_manager_1.active_blocks();
let tokens_phase1 = seq_manager_1.active_tokens().await; let tokens_phase1 = seq_manager_1.active_tokens();
// Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2 // Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// We now have: // We now have:
...@@ -1450,8 +1111,8 @@ mod tests { ...@@ -1450,8 +1111,8 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty // Query seq_manager_2 to verify everything is empty
let blocks_phase2 = seq_manager_2.active_blocks().await; let blocks_phase2 = seq_manager_2.active_blocks();
let tokens_phase2 = seq_manager_2.active_tokens().await; let tokens_phase2 = seq_manager_2.active_tokens();
// Verify phase 2 results - everything should be empty for all 3 workers // Verify phase 2 results - everything should be empty for all 3 workers
let all_workers = vec![ let all_workers = vec![
...@@ -1579,7 +1240,7 @@ mod tests { ...@@ -1579,7 +1240,7 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2 // Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let tokens_phase1 = seq_manager_1.active_tokens().await; let tokens_phase1 = seq_manager_1.active_tokens();
// Verify that seq_manager_1 sees all requests including request_2 from thread 2 // Verify that seq_manager_1 sees all requests including request_2 from thread 2
let worker_0 = WorkerWithDpRank::from_worker_id(0); let worker_0 = WorkerWithDpRank::from_worker_id(0);
...@@ -1621,7 +1282,7 @@ mod tests { ...@@ -1621,7 +1282,7 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty // Query seq_manager_2 to verify everything is empty
let tokens_phase2 = seq_manager_2.active_tokens().await; let tokens_phase2 = seq_manager_2.active_tokens();
// Verify phase 2 results - everything should be empty // Verify phase 2 results - everything should be empty
for worker_id in 0..=2 { for worker_id in 0..=2 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment