Unverified Commit 0cb1d733 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: fold scheduling into queue so backpressure actually works (#6470)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent a9e06960
......@@ -935,9 +935,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.43"
version = "0.4.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118"
checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
dependencies = [
"iana-time-zone",
"js-sys",
......@@ -1613,9 +1613,9 @@ dependencies = [
[[package]]
name = "deranged"
version = "0.5.6"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc3dc5ad92c2e2d1c193bbbbdf2ea477cb81331de4f3103f267ca18368b988c4"
checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c"
dependencies = [
"powerfmt",
"serde_core",
......@@ -3566,9 +3566,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]]
name = "jiff"
version = "0.2.20"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c867c356cc096b33f4981825ab281ecba3db0acefe60329f044c1789d94c6543"
checksum = "b3e3d65f018c6ae946ab16e80944b97096ed73c35b221d1c478a6c81d8f57940"
dependencies = [
"jiff-static",
"jiff-tzdb-platform",
......@@ -3581,9 +3581,9 @@ dependencies = [
[[package]]
name = "jiff-static"
version = "0.2.20"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7946b4325269738f270bb55b3c19ab5c5040525f83fd625259422a9d25d9be5"
checksum = "a17c2b211d863c7fde02cbea8a3c1a439b98e109286554f2860bdded7ff83818"
dependencies = [
"proc-macro2",
"quote",
......@@ -3617,9 +3617,9 @@ dependencies = [
[[package]]
name = "js-sys"
version = "0.3.87"
version = "0.3.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21"
checksum = "c7e709f3e3d22866f9c25b3aff01af289b18422cc8b4262fb19103ee80fe513d"
dependencies = [
"once_cell",
"wasm-bindgen",
......@@ -4011,9 +4011,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53"
[[package]]
name = "litemap"
......@@ -5792,9 +5792,9 @@ dependencies = [
[[package]]
name = "pulldown-cmark"
version = "0.13.0"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0"
checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6"
dependencies = [
"bitflags 2.11.0",
"memchr",
......@@ -6543,14 +6543,14 @@ dependencies = [
[[package]]
name = "rustix"
version = "1.1.3"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
dependencies = [
"bitflags 2.11.0",
"errno",
"libc",
"linux-raw-sys 0.11.0",
"linux-raw-sys 0.12.1",
"windows-sys 0.61.2",
]
......@@ -7045,9 +7045,9 @@ dependencies = [
[[package]]
name = "serial_test"
version = "3.3.1"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555"
checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f"
dependencies = [
"futures-executor",
"futures-util",
......@@ -7060,9 +7060,9 @@ dependencies = [
[[package]]
name = "serial_test_derive"
version = "3.3.1"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83"
checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9"
dependencies = [
"proc-macro2",
"quote",
......@@ -7428,7 +7428,7 @@ dependencies = [
"fastrand",
"getrandom 0.4.1",
"once_cell",
"rustix 1.1.3",
"rustix 1.1.4",
"windows-sys 0.61.2",
]
......@@ -8626,9 +8626,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen"
version = "0.2.110"
version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866"
checksum = "ec1adf1535672f5b7824f817792b1afd731d7e843d2d04ec8f27e8cb51edd8ac"
dependencies = [
"cfg-if 1.0.4",
"once_cell",
......@@ -8639,9 +8639,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.60"
version = "0.4.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da"
checksum = "fe88540d1c934c4ec8e6db0afa536876c5441289d7f9f9123d4f065ac1250a6b"
dependencies = [
"cfg-if 1.0.4",
"futures-util",
......@@ -8653,9 +8653,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.110"
version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52"
checksum = "19e638317c08b21663aed4d2b9a2091450548954695ff4efa75bff5fa546b3b1"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
......@@ -8663,9 +8663,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.110"
version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309"
checksum = "2c64760850114d03d5f65457e96fc988f11f01d38fbaa51b254e4ab5809102af"
dependencies = [
"bumpalo",
"proc-macro2",
......@@ -8676,9 +8676,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.110"
version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53"
checksum = "60eecd4fe26177cfa3339eb00b4a36445889ba3ad37080c2429879718e20ca41"
dependencies = [
"unicode-ident",
]
......@@ -8732,9 +8732,9 @@ dependencies = [
[[package]]
name = "web-sys"
version = "0.3.87"
version = "0.3.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff"
checksum = "9d6bb20ed2d9572df8584f6dc81d68a41a625cadc6f15999d649a70ce7e3597a"
dependencies = [
"js-sys",
"wasm-bindgen",
......
......@@ -262,10 +262,9 @@ def main():
for tier in TIERS:
logger.info(f" {tier} priority: {len(tier_requests[tier])} requests")
# Use different aiperf random seeds per run so that the generated prompts
# differ, preventing mocker KV cache hits between runs.
baseline_seed = args.seed
priority_seed = args.seed + 1
# Offset hash_ids for the priority run so it starts with a cold KV cache,
# keeping the comparison fair. Same seed for both runs so prompts match.
priority_tier_requests = offset_hash_ids(tier_requests)
# Run 1: Baseline (same split, no priority tagging)
baseline_dir = os.path.join(args.output_dir, "baseline")
......@@ -277,20 +276,20 @@ def main():
baseline_dir,
tag_priority=False,
logger=logger,
seed=baseline_seed,
seed=args.seed,
)
# Run 2: With priority tagging
# Run 2: With priority tagging (offset hash_ids for cold cache)
priority_dir = os.path.join(args.output_dir, "priority")
logger.info("=== Running with priority tagging ===")
run_concurrent_streams(
args,
tier_requests,
priority_tier_requests,
priority_values,
priority_dir,
tag_priority=True,
logger=logger,
seed=priority_seed,
seed=args.seed,
)
# Plot comparison
......
......@@ -2,17 +2,17 @@
// SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::collections::BinaryHeap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify};
use crate::discovery::RuntimeConfigWatch;
use tokio::sync::Mutex;
use super::WorkerSelector;
use super::protocols::WorkerWithDpRank;
use super::scheduler::SchedulingRequest;
use super::sequence::ActiveSequencesMultiWorker;
use super::scheduler::{SchedulingRequest, SchedulingResponse};
use super::sequence::{ActiveSequencesMultiWorker, SequenceRequest};
use crate::discovery::RuntimeConfigWatch;
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;
......@@ -45,40 +45,41 @@ impl PartialOrd for QueueEntry {
}
}
/// Queue for managing scheduling requests with interior mutability.
/// Requests are held in `pending` when all workers are busy, and moved to `ready` when capacity frees up.
/// If queueing is disabled (threshold_frac is None), all requests go directly to `ready`.
/// Requests are ordered by effective arrival time: arrival_offset - priority_jump.
/// Queue that gates scheduling requests behind a capacity check.
/// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`.
/// When capacity frees up (`update()`), pending requests are scheduled in priority order.
/// If queueing is disabled (threshold_frac is None), requests are scheduled immediately.
pub struct SchedulerQueue {
pending: Mutex<BinaryHeap<QueueEntry>>,
ready: Mutex<VecDeque<SchedulingRequest>>,
slots: Arc<ActiveSequencesMultiWorker>,
workers_with_configs: RuntimeConfigWatch,
ready_notify: Arc<Notify>,
/// Cached threshold fraction; None means queueing is disabled.
threshold_frac: Option<f64>,
/// Reference instant for computing arrival offsets.
start_time: Instant,
block_size: u32,
selector: Box<dyn WorkerSelector + Send + Sync>,
}
impl SchedulerQueue {
pub fn new(
slots: Arc<ActiveSequencesMultiWorker>,
workers_with_configs: RuntimeConfigWatch,
ready_notify: Arc<Notify>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Box<dyn WorkerSelector + Send + Sync>,
) -> Self {
if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}");
}
Self {
pending: Mutex::new(BinaryHeap::new()),
ready: Mutex::new(VecDeque::new()),
slots,
workers_with_configs,
ready_notify,
threshold_frac,
start_time: Instant::now(),
block_size,
selector,
}
}
......@@ -94,11 +95,11 @@ impl SchedulerQueue {
}
/// Enqueue a new request.
/// If queueing is disabled (threshold not set), fast-track to ready.
/// Otherwise, check busy condition and place in ready or pending.
/// If queueing is disabled or workers have capacity, schedule immediately.
/// Otherwise park in the pending heap.
pub async fn enqueue(&self, request: SchedulingRequest) {
let Some(threshold) = self.threshold_frac else {
self.ready.lock().await.push_back(request);
self.schedule(request).await;
return;
};
......@@ -107,41 +108,87 @@ impl SchedulerQueue {
let entry = self.make_entry(request);
self.pending.lock().await.push(entry);
} else {
self.ready.lock().await.push_back(request);
self.schedule(request).await;
}
}
/// Try to dequeue the highest-priority request from the ready queue.
pub async fn try_dequeue(&self) -> Option<SchedulingRequest> {
self.ready.lock().await.pop_front()
}
/// Called on prefill_complete/free. Re-checks pending requests and moves eligible to ready.
/// Notifies scheduler loop if any requests were moved.
/// Called on prefill_complete/free. Drains pending requests while workers have capacity.
/// Each scheduled request updates active_tokens via add_request, so the busy check
/// sees fresh state on the next iteration.
pub async fn update(&self) {
let Some(threshold) = self.threshold_frac else {
return;
};
let mut moved = false;
loop {
if self.pending.lock().await.is_empty() {
break;
}
if self.all_workers_busy(threshold).await {
break;
}
let entry = self.pending.lock().await.pop();
if let Some(entry) = entry {
tracing::debug!("moving request from pending to ready");
self.ready.lock().await.push_back(entry.request);
moved = true;
} else {
let Some(entry) = self.pending.lock().await.pop() else {
break;
};
tracing::debug!("scheduling request from pending queue");
self.schedule(entry.request).await;
}
}
/// Run the full scheduling pipeline for a single request:
/// compute potential load → select worker → respond → book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens(
request.token_seq.clone(),
request.isl_tokens,
request.overlaps.clone(),
)
.await;
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
let selection = {
let workers = self.workers_with_configs.borrow();
self.selector
.select_worker(&workers, &request, self.block_size)
};
let selection = match selection {
Ok(s) => s,
Err(e) => {
tracing::warn!("scheduling failed: {e}");
request.respond(Err(e));
return;
}
};
request.respond(Ok(SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
}));
if !request.update_states {
return;
}
if moved {
self.ready_notify.notify_one();
let Some(request_id) = request.maybe_request_id else {
tracing::error!("No request_id provided to add_request to the slot tracker");
return;
};
if let Err(e) = self
.slots
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
expected_output_tokens: None,
worker: selection.worker,
lora_name: request.lora_name.clone(),
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
......@@ -13,14 +19,6 @@ use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use tokio::sync::Notify;
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
use dynamo_tokens::SequenceHash;
......@@ -65,20 +63,17 @@ pub struct SchedulingRequest {
pub lora_name: Option<String>,
/// Priority jump in seconds; decreases effective arrival time in the queue.
pub priority_jump: f64,
// Option to take it out to send the response without moving the struct
resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
}
impl SchedulingRequest {
pub fn respond(&mut self, response: SchedulingResponse) {
// Changed to &mut self
if let Some(tx) = self.resp_tx.take() {
// Use take() to extract the sender
if tx.send(response).is_err() {
tracing::error!("failed to send response to requestor");
}
} else {
pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
let Some(tx) = self.resp_tx.take() else {
tracing::error!("respond called multiple times on same request");
return;
};
if tx.send(result).is_err() {
tracing::error!("failed to send response to requestor");
}
}
}
......@@ -150,29 +145,25 @@ impl KvScheduler {
}
});
let slots_clone = slots.clone();
let scheduler_rx = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
// Create queue with shared notify for waking the scheduler loop
let ready_notify = Arc::new(Notify::new());
let queue = Arc::new(SchedulerQueue::new(
slots.clone(),
workers_with_configs.clone(),
ready_notify.clone(),
kv_router_config.router_queue_threshold,
block_size,
selector,
));
let queue_clone = queue.clone();
// Background task to handle scheduling requests
// Background task: receive requests and periodically recheck pending
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
tracing::trace!("scheduler background task started");
loop {
// Use select! to wait on: new request, ready_notify, periodic recheck, or cancellation
tokio::select! {
_ = scheduler_cancel_token.cancelled() => {
tracing::trace!("scheduler background task shutting down");
......@@ -186,74 +177,10 @@ impl KvScheduler {
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
}
_ = ready_notify.notified() => {
// Woken by update() after prefill_complete/free - just continue to drain ready queue
}
_ = recheck_interval.tick() => {
// Periodic recheck to prevent requests stuck in pending
queue_clone.update().await;
}
}
// Drain ALL ready requests (each iteration uses fresh slot state)
while let Some(mut request) = queue_clone.try_dequeue().await {
let (decode_blocks, prefill_tokens) = slots_clone
.potential_blocks_and_tokens(
request.token_seq.clone(),
request.isl_tokens,
request.overlaps.clone(),
)
.await;
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
// Read the current workers configuration from watch receiver
let workers: HashMap<WorkerId, ModelRuntimeConfig> =
scheduler_rx.borrow().clone();
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => {
let response = SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
};
request.respond(response);
// Skip state update if not requested
if !request.update_states {
continue;
}
let Some(request_id) = request.maybe_request_id else {
tracing::error!(
"No request_id provided to add_request to the slot tracker"
);
continue;
};
if let Err(e) = slots_clone
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
expected_output_tokens: None,
worker: selection.worker,
lora_name: request.lora_name.clone(),
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
Err(KvSchedulerError::NoEndpoints) => {
tracing::warn!("no endpoints available, dropping request");
}
Err(e) => {
tracing::error!("error scheduling request: {:?}", e);
}
}
}
}
tracing::trace!("background endpoint subscriber shutting down");
......@@ -306,7 +233,7 @@ impl KvScheduler {
let response = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
.map_err(|_| KvSchedulerError::SubscriberShutdown)??;
#[cfg(feature = "bench")]
let total_elapsed = start.elapsed();
......@@ -430,25 +357,14 @@ fn softmax_sample(
// All values are the same, uniform probability
vec![1.0 / keys.len() as f64; keys.len()]
} else {
// Normalize values
let normalized: Vec<_> = values
.iter()
.map(|&v| {
// Lower is better, so negate
// Note we don't need to do actual min-max norm here, just off by an offset
let norm = v / (max_val - min_val);
-norm
})
.collect();
// Apply temperature and softmax
let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();
// Fused normalize → negate → scale → exp, then normalize probabilities
let range = max_val - min_val;
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum_exp: f64 = exp_values.iter().sum();
exp_values.iter().map(|&v| v / sum_exp).collect()
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
probs
};
// Sample from the probability distribution
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment