Unverified commit 36e8b8c8, authored by Yan Ru Pei, committed by GitHub
Browse files

fix(kv-router): wake queue on remote lifecycle (#8097)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent a0c7025d
...@@ -30,6 +30,7 @@ where ...@@ -30,6 +30,7 @@ where
request_tx: mpsc::Sender<SchedulingRequest>, request_tx: mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker<P>>, slots: Arc<ActiveSequencesMultiWorker<P>>,
queue: Arc<SchedulerQueue<P, C, S, Sel>>, queue: Arc<SchedulerQueue<P, C, S, Sel>>,
queue_updates: watch::Sender<()>,
track_prefill_tokens_default: bool, track_prefill_tokens_default: bool,
worker_type: &'static str, worker_type: &'static str,
} }
...@@ -107,8 +108,34 @@ where ...@@ -107,8 +108,34 @@ where
policy, policy,
prefill_load_estimator, prefill_load_estimator,
)); ));
let (queue_updates, _) = watch::channel(());
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue); let queue_clone = Arc::clone(&queue);
let queue_remote_updates = Arc::clone(&queue);
let mut remote_state_updates = slots.subscribe_remote_state_changes();
let remote_update_cancel_token = cancellation_token.clone();
let queue_updates_remote = queue_updates.clone();
tokio::spawn(async move {
tracing::trace!("LocalScheduler remote state listener started");
loop {
tokio::select! {
_ = remote_update_cancel_token.cancelled() => {
tracing::trace!("LocalScheduler remote state listener shutting down");
break;
}
result = remote_state_updates.changed() => {
if result.is_err() {
tracing::trace!("LocalScheduler remote state listener shutting down");
break;
}
queue_remote_updates.update().await;
let _ = queue_updates_remote.send(());
}
}
}
});
tokio::spawn(async move { tokio::spawn(async move {
let mut request_rx = request_rx; let mut request_rx = request_rx;
...@@ -140,6 +167,7 @@ where ...@@ -140,6 +167,7 @@ where
request_tx, request_tx,
slots, slots,
queue, queue,
queue_updates,
track_prefill_tokens_default, track_prefill_tokens_default,
worker_type, worker_type,
} }
...@@ -219,6 +247,10 @@ where ...@@ -219,6 +247,10 @@ where
self.worker_type self.worker_type
} }
/// Returns a new `watch::Receiver` subscribed to this scheduler's `queue_updates` channel.
///
/// In the code visible here, the remote state listener task sends `()` on this channel
/// after it has awaited `update()` on the queue in response to a remote state change.
/// The payload is `()`, so a receiver learns only that an update happened.
///
/// NOTE(review): other senders may exist outside this view — confirm before treating
/// the remote listener as the only source of notifications.
pub fn subscribe_queue_updates(&self) -> watch::Receiver<()> {
self.queue_updates.subscribe()
}
pub fn add_output_block( pub fn add_output_block(
&self, &self,
request_id: &str, request_id: &str,
...@@ -277,15 +309,26 @@ mod tests { ...@@ -277,15 +309,26 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::watch; use tokio::sync::{mpsc, watch};
use super::*; use super::*;
use crate::protocols::OverlapScores; use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores};
use crate::scheduling::PrefillLoadEstimator; use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy; use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector; use crate::scheduling::selector::DefaultWorkerSelector;
use crate::sequences::SequenceSubscriber;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
/// Test double for `SequenceSubscriber` that is fed from an in-memory unbounded channel.
struct TestSequenceSubscriber {
// Receiving half of the channel; events pushed into the matching sender are
// what this subscriber hands out.
rx: mpsc::UnboundedReceiver<ActiveSequenceEvent>,
}
impl SequenceSubscriber for TestSequenceSubscriber {
    /// Waits for the next event on the channel and yields it wrapped in `Ok`.
    ///
    /// Returns `None` once `recv()` on the underlying receiver returns `None`;
    /// this subscriber never produces an `Err` item.
    async fn next_event(&mut self) -> Option<anyhow::Result<ActiveSequenceEvent>> {
        let event = self.rx.recv().await?;
        Some(Ok(event))
    }
}
struct FixedPrefillLoadEstimator { struct FixedPrefillLoadEstimator {
duration: Duration, duration: Duration,
} }
...@@ -344,6 +387,31 @@ mod tests { ...@@ -344,6 +387,31 @@ mod tests {
(scheduler, slots, cfg_tx, cancel_token) (scheduler, slots, cfg_tx, cancel_token)
} }
/// Starts replica sync on `slots` using a channel-backed `TestSequenceSubscriber`
/// and returns the sender half, so a test can inject `ActiveSequenceEvent`s.
///
/// `slots.start_replica_sync` receives a clone of `cancel_token`.
fn start_replica_sync(
    slots: &Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
    cancel_token: &CancellationToken,
) -> mpsc::UnboundedSender<ActiveSequenceEvent> {
    let (event_tx, event_rx) = mpsc::unbounded_channel();
    let subscriber = TestSequenceSubscriber { rx: event_rx };
    slots.start_replica_sync(subscriber, cancel_token.clone());
    event_tx
}
/// Polls `scheduler.pending_count()` every 5 ms until it equals `expected`.
///
/// # Panics
///
/// Panics if the count has not reached `expected` within 250 ms. The panic
/// message reports both the expected count and the count observed at the time
/// of the timeout, so a failing test shows what the queue actually looked like
/// instead of an opaque `Elapsed` error.
async fn wait_for_pending_count(
    scheduler: &Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
    expected: usize,
) {
    let reached_expected = async {
        while scheduler.pending_count() != expected {
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    };

    let outcome = tokio::time::timeout(Duration::from_millis(250), reached_expected).await;
    if outcome.is_err() {
        panic!(
            "timed out waiting for pending_count == {expected}; last observed {}",
            scheduler.pending_count()
        );
    }
}
#[tokio::test] #[tokio::test]
async fn test_schedule_books_request_into_active_sequences() { async fn test_schedule_books_request_into_active_sequences() {
let mut workers = HashMap::new(); let mut workers = HashMap::new();
...@@ -472,8 +540,7 @@ mod tests { ...@@ -472,8 +540,7 @@ mod tests {
}) })
}; };
tokio::time::sleep(Duration::from_millis(25)).await; wait_for_pending_count(&scheduler, 1).await;
assert_eq!(scheduler.pending_count(), 1);
scheduler.mark_prefill_completed("req-1").await.unwrap(); scheduler.mark_prefill_completed("req-1").await.unwrap();
queued.await.unwrap().unwrap(); queued.await.unwrap().unwrap();
...@@ -482,6 +549,223 @@ mod tests { ...@@ -482,6 +549,223 @@ mod tests {
cancel_token.cancel(); cancel_token.cancel();
} }
// Checks that a `MarkPrefillCompleted` event for "req-1", delivered through
// replica sync, lets the pending "req-2" schedule call finish and leaves the
// scheduler with no pending requests.
#[tokio::test]
async fn test_remote_mark_prefill_completed_drains_pending_queue() {
// Single worker (key 0) configured with `max_num_batched_tokens: Some(64)`.
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
// NOTE(review): the meaning of `Some(0.5), true, None` is defined by
// `make_scheduler`, whose signature is outside this view.
let (scheduler, slots, _cfg_tx, cancel_token) =
make_scheduler(workers, Some(0.5), true, None);
// Sender used below to inject events into the replica-sync subscriber.
let event_tx = start_replica_sync(&slots, &cancel_token);
// First request: awaited to completion before the second one is issued.
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
// Second request runs on its own task so the test can observe it while it
// is still waiting.
let queued = {
let scheduler = Arc::clone(&scheduler);
tokio::spawn(async move {
scheduler
.schedule(
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
})
};
// Wait until the scheduler reports exactly one pending request.
wait_for_pending_count(&scheduler, 1).await;
// Inject the event for "req-1".
// NOTE(review): `router_id: 1` presumably identifies a peer router rather
// than the local one — confirm against `make_scheduler`.
event_tx
.send(ActiveSequenceEvent {
request_id: "req-1".to_string(),
worker: WorkerWithDpRank::new(0, 0),
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: 1,
lora_name: None,
})
.unwrap();
// The spawned schedule call must complete within 250 ms of the event.
tokio::time::timeout(Duration::from_millis(250), async {
queued.await.unwrap().unwrap();
})
.await
.unwrap();
assert_eq!(scheduler.pending_count(), 0);
cancel_token.cancel();
}
/// Checks that a subscriber obtained from `subscribe_queue_updates` is notified
/// after a `Free` event for "req-1" arrives through replica sync, and that by
/// the time the notification is observed the pending queue is already empty.
///
/// The final await on the spawned "req-2" task is bounded by the same 250 ms
/// timeout the sibling remote-event tests use, so a regression fails the test
/// instead of hanging the test run.
#[tokio::test]
async fn test_remote_queue_update_notification_fires_after_drain() {
    // Single worker (key 0) configured with `max_num_batched_tokens: Some(64)`.
    let mut workers = HashMap::new();
    workers.insert(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    );
    // NOTE(review): the meaning of `Some(0.5), true, None` is defined by
    // `make_scheduler`, whose signature is outside this view.
    let (scheduler, slots, _cfg_tx, cancel_token) =
        make_scheduler(workers, Some(0.5), true, None);
    let event_tx = start_replica_sync(&slots, &cancel_token);
    // Subscribe before any work is scheduled so the later `changed()` call
    // observes the notification sent after the event below.
    let mut queue_updates = scheduler.subscribe_queue_updates();

    // First request: awaited to completion before the second one is issued.
    scheduler
        .schedule(
            Some("req-1".to_string()),
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            None,
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    // Second request runs on its own task so the test can observe it pending.
    let queued = {
        let scheduler = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler
                .schedule(
                    Some("req-2".to_string()),
                    64,
                    Some(vec![5, 6, 7, 8]),
                    OverlapScores::default(),
                    None,
                    true,
                    None,
                    0.0,
                    None,
                    None,
                )
                .await
        })
    };

    wait_for_pending_count(&scheduler, 1).await;

    event_tx
        .send(ActiveSequenceEvent {
            request_id: "req-1".to_string(),
            worker: WorkerWithDpRank::new(0, 0),
            data: ActiveSequenceEventData::Free,
            router_id: 1,
            lora_name: None,
        })
        .unwrap();

    // The notification must arrive within 250 ms of the event.
    tokio::time::timeout(Duration::from_millis(250), queue_updates.changed())
        .await
        .unwrap()
        .unwrap();
    assert_eq!(scheduler.pending_count(), 0);

    // Bounded wait: fail with a clear message rather than hang if the spawned
    // schedule call never returns.
    tokio::time::timeout(Duration::from_millis(250), queued)
        .await
        .expect("queued request did not complete after the queue update notification")
        .unwrap()
        .unwrap();

    cancel_token.cancel();
}
/// Checks that a `Free` event for "req-1", delivered through replica sync, lets
/// the pending "req-2" schedule call finish within 250 ms and leaves the
/// scheduler with no pending requests.
#[tokio::test]
async fn test_remote_free_drains_pending_queue() {
    // Single worker (key 0) configured with `max_num_batched_tokens: Some(64)`.
    let worker_configs = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, slots, _cfg_tx, cancel_token) =
        make_scheduler(worker_configs, Some(0.5), true, None);
    let remote_events = start_replica_sync(&slots, &cancel_token);

    // First request: awaited to completion before the second one is issued.
    let first_request = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        None,
        0.0,
        None,
        None,
    );
    first_request.await.unwrap();

    // Second request runs on its own task so it can be observed while pending.
    let second_request = {
        let scheduler = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler
                .schedule(
                    Some("req-2".to_string()),
                    64,
                    Some(vec![5, 6, 7, 8]),
                    OverlapScores::default(),
                    None,
                    true,
                    None,
                    0.0,
                    None,
                    None,
                )
                .await
        })
    };

    wait_for_pending_count(&scheduler, 1).await;

    // Inject the remote `Free` for the first request.
    let free_event = ActiveSequenceEvent {
        request_id: "req-1".to_string(),
        worker: WorkerWithDpRank::new(0, 0),
        data: ActiveSequenceEventData::Free,
        router_id: 1,
        lora_name: None,
    };
    remote_events.send(free_event).unwrap();

    // The spawned schedule call must complete within 250 ms of the event.
    let second_request_done = async {
        second_request.await.unwrap().unwrap();
    };
    tokio::time::timeout(Duration::from_millis(250), second_request_done)
        .await
        .unwrap();

    assert_eq!(scheduler.pending_count(), 0);
    cancel_token.cancel();
}
#[tokio::test] #[tokio::test]
async fn test_free_updates_active_state() { async fn test_free_updates_active_state() {
let mut workers = HashMap::new(); let mut workers = HashMap::new();
......
...@@ -15,6 +15,7 @@ use parking_lot::RwLock; ...@@ -15,6 +15,7 @@ use parking_lot::RwLock;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::watch;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -144,6 +145,7 @@ pub struct ActiveSequencesMultiWorker<P: SequencePublisher> { ...@@ -144,6 +145,7 @@ pub struct ActiveSequencesMultiWorker<P: SequencePublisher> {
block_size: usize, block_size: usize,
router_id: u64, router_id: u64,
publisher: Arc<P>, publisher: Arc<P>,
remote_state_updates: watch::Sender<()>,
replica_sync: bool, replica_sync: bool,
worker_type: &'static str, worker_type: &'static str,
} }
...@@ -161,6 +163,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -161,6 +163,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
worker_type: &'static str, worker_type: &'static str,
) -> Self { ) -> Self {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
let (remote_state_updates, _) = watch::channel(());
Self { Self {
workers: RwLock::new(WorkerTable::new(block_size, &dp_range)), workers: RwLock::new(WorkerTable::new(block_size, &dp_range)),
...@@ -169,6 +172,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -169,6 +172,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
block_size, block_size,
router_id, router_id,
publisher: Arc::new(publisher), publisher: Arc::new(publisher),
remote_state_updates,
replica_sync, replica_sync,
worker_type, worker_type,
} }
...@@ -193,6 +197,14 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -193,6 +197,14 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}); });
} }
/// Subscribe to remote lifecycle updates that were applied through replica sync.
///
/// The queue uses this to react immediately when a peer router frees prompt
/// capacity locally.
///
/// Returns a fresh receiver on the `remote_state_updates` watch channel. In the
/// replica-sync event handling visible in this file, the channel is signalled
/// after a remote `Free` has been applied to a worker found in the worker table,
/// or after a remote `MarkPrefillCompleted` has been applied. The channel carries
/// `()`, so receivers learn only that something changed, not which request or
/// worker was affected.
pub fn subscribe_remote_state_changes(&self) -> watch::Receiver<()> {
self.remote_state_updates.subscribe()
}
/// Spawn a background task that subscribes to replica-sync events from peer routers /// Spawn a background task that subscribes to replica-sync events from peer routers
/// and applies them to the local state. /// and applies them to the local state.
pub fn start_replica_sync<S: SequenceSubscriber + 'static>( pub fn start_replica_sync<S: SequenceSubscriber + 'static>(
...@@ -235,6 +247,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -235,6 +247,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
// TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet. // TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet.
// Peer routers still approximate decay anchoring with local receive time. // Peer routers still approximate decay anchoring with local receive time.
let decay_now = Instant::now(); let decay_now = Instant::now();
let mut remote_capacity_changed = false;
match &event.data { match &event.data {
ActiveSequenceEventData::AddRequest { ActiveSequenceEventData::AddRequest {
token_sequence, token_sequence,
...@@ -278,6 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -278,6 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let table = self.workers.read(); let table = self.workers.read();
if let Some(&idx) = table.index.get(&worker) { if let Some(&idx) = table.index.get(&worker) {
table.slots[idx].1.write().free(&event.request_id, decay_now); table.slots[idx].1.write().free(&event.request_id, decay_now);
remote_capacity_changed = true;
} }
} }
self.request_to_lora.remove(&event.request_id); self.request_to_lora.remove(&event.request_id);
...@@ -292,10 +306,15 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -292,10 +306,15 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.1 .1
.write() .write()
.mark_prefill_completed(&event.request_id, decay_now); .mark_prefill_completed(&event.request_id, decay_now);
remote_capacity_changed = true;
} }
} }
} }
} }
if remote_capacity_changed {
let _ = self.remote_state_updates.send(());
}
} }
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled"); tracing::debug!("Subscription task cancelled");
......
...@@ -93,6 +93,7 @@ where ...@@ -93,6 +93,7 @@ where
let metrics_scheduler = Arc::clone(&inner); let metrics_scheduler = Arc::clone(&inner);
let metrics_cancel_token = component.drt().child_token(); let metrics_cancel_token = component.drt().child_token();
let mut queue_updates = inner.subscribe_queue_updates();
tokio::spawn(async move { tokio::spawn(async move {
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60)); let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count()); ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
...@@ -100,6 +101,13 @@ where ...@@ -100,6 +101,13 @@ where
loop { loop {
tokio::select! { tokio::select! {
_ = metrics_cancel_token.cancelled() => break, _ = metrics_cancel_token.cancelled() => break,
result = queue_updates.changed() => {
if result.is_err() {
break;
}
ROUTER_QUEUE_METRICS
.set_pending(worker_type, metrics_scheduler.pending_count());
}
_ = recheck_interval.tick() => { _ = recheck_interval.tick() => {
ROUTER_QUEUE_METRICS ROUTER_QUEUE_METRICS
.set_pending(worker_type, metrics_scheduler.pending_count()); .set_pending(worker_type, metrics_scheduler.pending_count());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment