Unverified Commit 36e8b8c8 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix(kv-router): wake queue on remote lifecycle (#8097)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent a0c7025d
......@@ -30,6 +30,7 @@ where
request_tx: mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker<P>>,
queue: Arc<SchedulerQueue<P, C, S, Sel>>,
queue_updates: watch::Sender<()>,
track_prefill_tokens_default: bool,
worker_type: &'static str,
}
......@@ -107,8 +108,34 @@ where
policy,
prefill_load_estimator,
));
let (queue_updates, _) = watch::channel(());
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue);
let queue_remote_updates = Arc::clone(&queue);
let mut remote_state_updates = slots.subscribe_remote_state_changes();
let remote_update_cancel_token = cancellation_token.clone();
let queue_updates_remote = queue_updates.clone();
tokio::spawn(async move {
tracing::trace!("LocalScheduler remote state listener started");
loop {
tokio::select! {
_ = remote_update_cancel_token.cancelled() => {
tracing::trace!("LocalScheduler remote state listener shutting down");
break;
}
result = remote_state_updates.changed() => {
if result.is_err() {
tracing::trace!("LocalScheduler remote state listener shutting down");
break;
}
queue_remote_updates.update().await;
let _ = queue_updates_remote.send(());
}
}
}
});
tokio::spawn(async move {
let mut request_rx = request_rx;
......@@ -140,6 +167,7 @@ where
request_tx,
slots,
queue,
queue_updates,
track_prefill_tokens_default,
worker_type,
}
......@@ -219,6 +247,10 @@ where
self.worker_type
}
pub fn subscribe_queue_updates(&self) -> watch::Receiver<()> {
self.queue_updates.subscribe()
}
pub fn add_output_block(
&self,
request_id: &str,
......@@ -277,15 +309,26 @@ mod tests {
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tokio::sync::{mpsc, watch};
use super::*;
use crate::protocols::OverlapScores;
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores};
use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
use crate::sequences::SequenceSubscriber;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
struct TestSequenceSubscriber {
rx: mpsc::UnboundedReceiver<ActiveSequenceEvent>,
}
impl SequenceSubscriber for TestSequenceSubscriber {
async fn next_event(&mut self) -> Option<anyhow::Result<ActiveSequenceEvent>> {
self.rx.recv().await.map(Ok)
}
}
struct FixedPrefillLoadEstimator {
duration: Duration,
}
......@@ -344,6 +387,31 @@ mod tests {
(scheduler, slots, cfg_tx, cancel_token)
}
fn start_replica_sync(
slots: &Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
cancel_token: &CancellationToken,
) -> mpsc::UnboundedSender<ActiveSequenceEvent> {
let (tx, rx) = mpsc::unbounded_channel();
slots.start_replica_sync(TestSequenceSubscriber { rx }, cancel_token.clone());
tx
}
async fn wait_for_pending_count(
scheduler: &Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
expected: usize,
) {
tokio::time::timeout(Duration::from_millis(250), async {
loop {
if scheduler.pending_count() == expected {
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
})
.await
.unwrap();
}
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
let mut workers = HashMap::new();
......@@ -472,8 +540,7 @@ mod tests {
})
};
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(scheduler.pending_count(), 1);
wait_for_pending_count(&scheduler, 1).await;
scheduler.mark_prefill_completed("req-1").await.unwrap();
queued.await.unwrap().unwrap();
......@@ -482,6 +549,223 @@ mod tests {
cancel_token.cancel();
}
#[tokio::test]
async fn test_remote_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) =
make_scheduler(workers, Some(0.5), true, None);
let event_tx = start_replica_sync(&slots, &cancel_token);
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
let queued = {
let scheduler = Arc::clone(&scheduler);
tokio::spawn(async move {
scheduler
.schedule(
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
})
};
wait_for_pending_count(&scheduler, 1).await;
event_tx
.send(ActiveSequenceEvent {
request_id: "req-1".to_string(),
worker: WorkerWithDpRank::new(0, 0),
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: 1,
lora_name: None,
})
.unwrap();
tokio::time::timeout(Duration::from_millis(250), async {
queued.await.unwrap().unwrap();
})
.await
.unwrap();
assert_eq!(scheduler.pending_count(), 0);
cancel_token.cancel();
}
#[tokio::test]
async fn test_remote_queue_update_notification_fires_after_drain() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) =
make_scheduler(workers, Some(0.5), true, None);
let event_tx = start_replica_sync(&slots, &cancel_token);
let mut queue_updates = scheduler.subscribe_queue_updates();
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
let queued = {
let scheduler = Arc::clone(&scheduler);
tokio::spawn(async move {
scheduler
.schedule(
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
})
};
wait_for_pending_count(&scheduler, 1).await;
event_tx
.send(ActiveSequenceEvent {
request_id: "req-1".to_string(),
worker: WorkerWithDpRank::new(0, 0),
data: ActiveSequenceEventData::Free,
router_id: 1,
lora_name: None,
})
.unwrap();
tokio::time::timeout(Duration::from_millis(250), queue_updates.changed())
.await
.unwrap()
.unwrap();
assert_eq!(scheduler.pending_count(), 0);
queued.await.unwrap().unwrap();
cancel_token.cancel();
}
#[tokio::test]
async fn test_remote_free_drains_pending_queue() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) =
make_scheduler(workers, Some(0.5), true, None);
let event_tx = start_replica_sync(&slots, &cancel_token);
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
let queued = {
let scheduler = Arc::clone(&scheduler);
tokio::spawn(async move {
scheduler
.schedule(
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
})
};
wait_for_pending_count(&scheduler, 1).await;
event_tx
.send(ActiveSequenceEvent {
request_id: "req-1".to_string(),
worker: WorkerWithDpRank::new(0, 0),
data: ActiveSequenceEventData::Free,
router_id: 1,
lora_name: None,
})
.unwrap();
tokio::time::timeout(Duration::from_millis(250), async {
queued.await.unwrap().unwrap();
})
.await
.unwrap();
assert_eq!(scheduler.pending_count(), 0);
cancel_token.cancel();
}
#[tokio::test]
async fn test_free_updates_active_state() {
let mut workers = HashMap::new();
......
......@@ -15,6 +15,7 @@ use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use tokio::sync::watch;
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -144,6 +145,7 @@ pub struct ActiveSequencesMultiWorker<P: SequencePublisher> {
block_size: usize,
router_id: u64,
publisher: Arc<P>,
remote_state_updates: watch::Sender<()>,
replica_sync: bool,
worker_type: &'static str,
}
......@@ -161,6 +163,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
worker_type: &'static str,
) -> Self {
assert!(block_size > 1, "block_size must be greater than 1");
let (remote_state_updates, _) = watch::channel(());
Self {
workers: RwLock::new(WorkerTable::new(block_size, &dp_range)),
......@@ -169,6 +172,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
block_size,
router_id,
publisher: Arc::new(publisher),
remote_state_updates,
replica_sync,
worker_type,
}
......@@ -193,6 +197,14 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
});
}
/// Subscribe to remote lifecycle updates that were applied through replica sync.
///
/// The queue uses this to react immediately when a peer router frees prompt
/// capacity locally.
pub fn subscribe_remote_state_changes(&self) -> watch::Receiver<()> {
self.remote_state_updates.subscribe()
}
/// Spawn a background task that subscribes to replica-sync events from peer routers
/// and applies them to the local state.
pub fn start_replica_sync<S: SequenceSubscriber + 'static>(
......@@ -235,6 +247,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
// TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet.
// Peer routers still approximate decay anchoring with local receive time.
let decay_now = Instant::now();
let mut remote_capacity_changed = false;
match &event.data {
ActiveSequenceEventData::AddRequest {
token_sequence,
......@@ -278,6 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let table = self.workers.read();
if let Some(&idx) = table.index.get(&worker) {
table.slots[idx].1.write().free(&event.request_id, decay_now);
remote_capacity_changed = true;
}
}
self.request_to_lora.remove(&event.request_id);
......@@ -292,10 +306,15 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.1
.write()
.mark_prefill_completed(&event.request_id, decay_now);
remote_capacity_changed = true;
}
}
}
}
if remote_capacity_changed {
let _ = self.remote_state_updates.send(());
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled");
......
......@@ -93,6 +93,7 @@ where
let metrics_scheduler = Arc::clone(&inner);
let metrics_cancel_token = component.drt().child_token();
let mut queue_updates = inner.subscribe_queue_updates();
tokio::spawn(async move {
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
......@@ -100,6 +101,13 @@ where
loop {
tokio::select! {
_ = metrics_cancel_token.cancelled() => break,
result = queue_updates.changed() => {
if result.is_err() {
break;
}
ROUTER_QUEUE_METRICS
.set_pending(worker_type, metrics_scheduler.pending_count());
}
_ = recheck_interval.tick() => {
ROUTER_QUEUE_METRICS
.set_pending(worker_type, metrics_scheduler.pending_count());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment