Unverified commit 0cb1d733, authored by Yan Ru Pei, committed by GitHub
Browse files

fix: fold scheduling into queue so backpressure actually works (#6470)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
Signed-off-by: Yan Ru Pei <yanrpei@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent a9e06960
...@@ -935,9 +935,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" ...@@ -935,9 +935,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "chrono" name = "chrono"
version = "0.4.43" version = "0.4.44"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
dependencies = [ dependencies = [
"iana-time-zone", "iana-time-zone",
"js-sys", "js-sys",
...@@ -1613,9 +1613,9 @@ dependencies = [ ...@@ -1613,9 +1613,9 @@ dependencies = [
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.5.6" version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc3dc5ad92c2e2d1c193bbbbdf2ea477cb81331de4f3103f267ca18368b988c4" checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c"
dependencies = [ dependencies = [
"powerfmt", "powerfmt",
"serde_core", "serde_core",
...@@ -3566,9 +3566,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" ...@@ -3566,9 +3566,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]] [[package]]
name = "jiff" name = "jiff"
version = "0.2.20" version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c867c356cc096b33f4981825ab281ecba3db0acefe60329f044c1789d94c6543" checksum = "b3e3d65f018c6ae946ab16e80944b97096ed73c35b221d1c478a6c81d8f57940"
dependencies = [ dependencies = [
"jiff-static", "jiff-static",
"jiff-tzdb-platform", "jiff-tzdb-platform",
...@@ -3581,9 +3581,9 @@ dependencies = [ ...@@ -3581,9 +3581,9 @@ dependencies = [
[[package]] [[package]]
name = "jiff-static" name = "jiff-static"
version = "0.2.20" version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7946b4325269738f270bb55b3c19ab5c5040525f83fd625259422a9d25d9be5" checksum = "a17c2b211d863c7fde02cbea8a3c1a439b98e109286554f2860bdded7ff83818"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -3617,9 +3617,9 @@ dependencies = [ ...@@ -3617,9 +3617,9 @@ dependencies = [
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.87" version = "0.3.88"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21" checksum = "c7e709f3e3d22866f9c25b3aff01af289b18422cc8b4262fb19103ee80fe513d"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"wasm-bindgen", "wasm-bindgen",
...@@ -4011,9 +4011,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" ...@@ -4011,9 +4011,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.11.0" version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53"
[[package]] [[package]]
name = "litemap" name = "litemap"
...@@ -5792,9 +5792,9 @@ dependencies = [ ...@@ -5792,9 +5792,9 @@ dependencies = [
[[package]] [[package]]
name = "pulldown-cmark" name = "pulldown-cmark"
version = "0.13.0" version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0" checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"memchr", "memchr",
...@@ -6543,14 +6543,14 @@ dependencies = [ ...@@ -6543,14 +6543,14 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "1.1.3" version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.11.0", "linux-raw-sys 0.12.1",
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
...@@ -7045,9 +7045,9 @@ dependencies = [ ...@@ -7045,9 +7045,9 @@ dependencies = [
[[package]] [[package]]
name = "serial_test" name = "serial_test"
version = "3.3.1" version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555" checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f"
dependencies = [ dependencies = [
"futures-executor", "futures-executor",
"futures-util", "futures-util",
...@@ -7060,9 +7060,9 @@ dependencies = [ ...@@ -7060,9 +7060,9 @@ dependencies = [
[[package]] [[package]]
name = "serial_test_derive" name = "serial_test_derive"
version = "3.3.1" version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83" checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -7428,7 +7428,7 @@ dependencies = [ ...@@ -7428,7 +7428,7 @@ dependencies = [
"fastrand", "fastrand",
"getrandom 0.4.1", "getrandom 0.4.1",
"once_cell", "once_cell",
"rustix 1.1.3", "rustix 1.1.4",
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
...@@ -8626,9 +8626,9 @@ dependencies = [ ...@@ -8626,9 +8626,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen" name = "wasm-bindgen"
version = "0.2.110" version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866" checksum = "ec1adf1535672f5b7824f817792b1afd731d7e843d2d04ec8f27e8cb51edd8ac"
dependencies = [ dependencies = [
"cfg-if 1.0.4", "cfg-if 1.0.4",
"once_cell", "once_cell",
...@@ -8639,9 +8639,9 @@ dependencies = [ ...@@ -8639,9 +8639,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-futures" name = "wasm-bindgen-futures"
version = "0.4.60" version = "0.4.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da" checksum = "fe88540d1c934c4ec8e6db0afa536876c5441289d7f9f9123d4f065ac1250a6b"
dependencies = [ dependencies = [
"cfg-if 1.0.4", "cfg-if 1.0.4",
"futures-util", "futures-util",
...@@ -8653,9 +8653,9 @@ dependencies = [ ...@@ -8653,9 +8653,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro" name = "wasm-bindgen-macro"
version = "0.2.110" version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52" checksum = "19e638317c08b21663aed4d2b9a2091450548954695ff4efa75bff5fa546b3b1"
dependencies = [ dependencies = [
"quote", "quote",
"wasm-bindgen-macro-support", "wasm-bindgen-macro-support",
...@@ -8663,9 +8663,9 @@ dependencies = [ ...@@ -8663,9 +8663,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro-support" name = "wasm-bindgen-macro-support"
version = "0.2.110" version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309" checksum = "2c64760850114d03d5f65457e96fc988f11f01d38fbaa51b254e4ab5809102af"
dependencies = [ dependencies = [
"bumpalo", "bumpalo",
"proc-macro2", "proc-macro2",
...@@ -8676,9 +8676,9 @@ dependencies = [ ...@@ -8676,9 +8676,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-shared" name = "wasm-bindgen-shared"
version = "0.2.110" version = "0.2.111"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53" checksum = "60eecd4fe26177cfa3339eb00b4a36445889ba3ad37080c2429879718e20ca41"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
...@@ -8732,9 +8732,9 @@ dependencies = [ ...@@ -8732,9 +8732,9 @@ dependencies = [
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.87" version = "0.3.88"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff" checksum = "9d6bb20ed2d9572df8584f6dc81d68a41a625cadc6f15999d649a70ce7e3597a"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"wasm-bindgen", "wasm-bindgen",
......
...@@ -262,10 +262,9 @@ def main(): ...@@ -262,10 +262,9 @@ def main():
for tier in TIERS: for tier in TIERS:
logger.info(f" {tier} priority: {len(tier_requests[tier])} requests") logger.info(f" {tier} priority: {len(tier_requests[tier])} requests")
# Use different aiperf random seeds per run so that the generated prompts # Offset hash_ids for the priority run so it starts with a cold KV cache,
# differ, preventing mocker KV cache hits between runs. # keeping the comparison fair. Same seed for both runs so prompts match.
baseline_seed = args.seed priority_tier_requests = offset_hash_ids(tier_requests)
priority_seed = args.seed + 1
# Run 1: Baseline (same split, no priority tagging) # Run 1: Baseline (same split, no priority tagging)
baseline_dir = os.path.join(args.output_dir, "baseline") baseline_dir = os.path.join(args.output_dir, "baseline")
...@@ -277,20 +276,20 @@ def main(): ...@@ -277,20 +276,20 @@ def main():
baseline_dir, baseline_dir,
tag_priority=False, tag_priority=False,
logger=logger, logger=logger,
seed=baseline_seed, seed=args.seed,
) )
# Run 2: With priority tagging # Run 2: With priority tagging (offset hash_ids for cold cache)
priority_dir = os.path.join(args.output_dir, "priority") priority_dir = os.path.join(args.output_dir, "priority")
logger.info("=== Running with priority tagging ===") logger.info("=== Running with priority tagging ===")
run_concurrent_streams( run_concurrent_streams(
args, args,
tier_requests, priority_tier_requests,
priority_values, priority_values,
priority_dir, priority_dir,
tag_priority=True, tag_priority=True,
logger=logger, logger=logger,
seed=priority_seed, seed=args.seed,
) )
# Plot comparison # Plot comparison
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering; use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque}; use std::collections::BinaryHeap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify}; use tokio::sync::Mutex;
use crate::discovery::RuntimeConfigWatch;
use super::WorkerSelector;
use super::protocols::WorkerWithDpRank; use super::protocols::WorkerWithDpRank;
use super::scheduler::SchedulingRequest; use super::scheduler::{SchedulingRequest, SchedulingResponse};
use super::sequence::ActiveSequencesMultiWorker; use super::sequence::{ActiveSequencesMultiWorker, SequenceRequest};
use crate::discovery::RuntimeConfigWatch;
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker) /// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000; const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;
...@@ -45,40 +45,41 @@ impl PartialOrd for QueueEntry { ...@@ -45,40 +45,41 @@ impl PartialOrd for QueueEntry {
} }
} }
/// Queue for managing scheduling requests with interior mutability. /// Queue that gates scheduling requests behind a capacity check.
/// Requests are held in `pending` when all workers are busy, and moved to `ready` when capacity frees up. /// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`.
/// If queueing is disabled (threshold_frac is None), all requests go directly to `ready`. /// When capacity frees up (`update()`), pending requests are scheduled in priority order.
/// Requests are ordered by effective arrival time: arrival_offset - priority_jump. /// If queueing is disabled (threshold_frac is None), requests are scheduled immediately.
pub struct SchedulerQueue { pub struct SchedulerQueue {
pending: Mutex<BinaryHeap<QueueEntry>>, pending: Mutex<BinaryHeap<QueueEntry>>,
ready: Mutex<VecDeque<SchedulingRequest>>,
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMultiWorker>,
workers_with_configs: RuntimeConfigWatch, workers_with_configs: RuntimeConfigWatch,
ready_notify: Arc<Notify>,
/// Cached threshold fraction; None means queueing is disabled. /// Cached threshold fraction; None means queueing is disabled.
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
/// Reference instant for computing arrival offsets. /// Reference instant for computing arrival offsets.
start_time: Instant, start_time: Instant,
block_size: u32,
selector: Box<dyn WorkerSelector + Send + Sync>,
} }
impl SchedulerQueue { impl SchedulerQueue {
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMultiWorker>,
workers_with_configs: RuntimeConfigWatch, workers_with_configs: RuntimeConfigWatch,
ready_notify: Arc<Notify>,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
block_size: u32,
selector: Box<dyn WorkerSelector + Send + Sync>,
) -> Self { ) -> Self {
if let Some(frac) = threshold_frac { if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}"); tracing::info!("Router queue enabled with threshold fraction {frac}");
} }
Self { Self {
pending: Mutex::new(BinaryHeap::new()), pending: Mutex::new(BinaryHeap::new()),
ready: Mutex::new(VecDeque::new()),
slots, slots,
workers_with_configs, workers_with_configs,
ready_notify,
threshold_frac, threshold_frac,
start_time: Instant::now(), start_time: Instant::now(),
block_size,
selector,
} }
} }
...@@ -94,11 +95,11 @@ impl SchedulerQueue { ...@@ -94,11 +95,11 @@ impl SchedulerQueue {
} }
/// Enqueue a new request. /// Enqueue a new request.
/// If queueing is disabled (threshold not set), fast-track to ready. /// If queueing is disabled or workers have capacity, schedule immediately.
/// Otherwise, check busy condition and place in ready or pending. /// Otherwise park in the pending heap.
pub async fn enqueue(&self, request: SchedulingRequest) { pub async fn enqueue(&self, request: SchedulingRequest) {
let Some(threshold) = self.threshold_frac else { let Some(threshold) = self.threshold_frac else {
self.ready.lock().await.push_back(request); self.schedule(request).await;
return; return;
}; };
...@@ -107,41 +108,87 @@ impl SchedulerQueue { ...@@ -107,41 +108,87 @@ impl SchedulerQueue {
let entry = self.make_entry(request); let entry = self.make_entry(request);
self.pending.lock().await.push(entry); self.pending.lock().await.push(entry);
} else { } else {
self.ready.lock().await.push_back(request); self.schedule(request).await;
} }
} }
/// Try to dequeue the highest-priority request from the ready queue. /// Called on prefill_complete/free. Drains pending requests while workers have capacity.
pub async fn try_dequeue(&self) -> Option<SchedulingRequest> { /// Each scheduled request updates active_tokens via add_request, so the busy check
self.ready.lock().await.pop_front() /// sees fresh state on the next iteration.
}
/// Called on prefill_complete/free. Re-checks pending requests and moves eligible to ready.
/// Notifies scheduler loop if any requests were moved.
pub async fn update(&self) { pub async fn update(&self) {
let Some(threshold) = self.threshold_frac else { let Some(threshold) = self.threshold_frac else {
return; return;
}; };
let mut moved = false;
loop { loop {
if self.pending.lock().await.is_empty() {
break;
}
if self.all_workers_busy(threshold).await { if self.all_workers_busy(threshold).await {
break; break;
} }
let entry = self.pending.lock().await.pop(); let Some(entry) = self.pending.lock().await.pop() else {
if let Some(entry) = entry {
tracing::debug!("moving request from pending to ready");
self.ready.lock().await.push_back(entry.request);
moved = true;
} else {
break; break;
};
tracing::debug!("scheduling request from pending queue");
self.schedule(entry.request).await;
}
}
/// Run the full scheduling pipeline for a single request:
/// compute potential load → select worker → respond → book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens(
request.token_seq.clone(),
request.isl_tokens,
request.overlaps.clone(),
)
.await;
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
let selection = {
let workers = self.workers_with_configs.borrow();
self.selector
.select_worker(&workers, &request, self.block_size)
};
let selection = match selection {
Ok(s) => s,
Err(e) => {
tracing::warn!("scheduling failed: {e}");
request.respond(Err(e));
return;
} }
};
request.respond(Ok(SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
}));
if !request.update_states {
return;
} }
if moved {
self.ready_notify.notify_one(); let Some(request_id) = request.maybe_request_id else {
tracing::error!("No request_id provided to add_request to the slot tracker");
return;
};
if let Err(e) = self
.slots
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
expected_output_tokens: None,
worker: selection.worker,
lora_name: request.lora_name.clone(),
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}");
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
use crate::discovery::RuntimeConfigWatch; use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result; use anyhow::Result;
...@@ -13,14 +19,6 @@ use std::sync::Arc; ...@@ -13,14 +19,6 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
use std::time::Instant; use std::time::Instant;
use tokio::sync::Notify;
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
...@@ -65,20 +63,17 @@ pub struct SchedulingRequest { ...@@ -65,20 +63,17 @@ pub struct SchedulingRequest {
pub lora_name: Option<String>, pub lora_name: Option<String>,
/// Priority jump in seconds; decreases effective arrival time in the queue. /// Priority jump in seconds; decreases effective arrival time in the queue.
pub priority_jump: f64, pub priority_jump: f64,
// Option to take it out to send the response without moving the struct resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
} }
impl SchedulingRequest { impl SchedulingRequest {
pub fn respond(&mut self, response: SchedulingResponse) { pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
// Changed to &mut self let Some(tx) = self.resp_tx.take() else {
if let Some(tx) = self.resp_tx.take() {
// Use take() to extract the sender
if tx.send(response).is_err() {
tracing::error!("failed to send response to requestor");
}
} else {
tracing::error!("respond called multiple times on same request"); tracing::error!("respond called multiple times on same request");
return;
};
if tx.send(result).is_err() {
tracing::error!("failed to send response to requestor");
} }
} }
} }
...@@ -150,29 +145,25 @@ impl KvScheduler { ...@@ -150,29 +145,25 @@ impl KvScheduler {
} }
}); });
let slots_clone = slots.clone();
let scheduler_rx = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token(); let scheduler_cancel_token = component.drt().primary_token();
// Create queue with shared notify for waking the scheduler loop
let ready_notify = Arc::new(Notify::new());
let queue = Arc::new(SchedulerQueue::new( let queue = Arc::new(SchedulerQueue::new(
slots.clone(), slots.clone(),
workers_with_configs.clone(), workers_with_configs.clone(),
ready_notify.clone(),
kv_router_config.router_queue_threshold, kv_router_config.router_queue_threshold,
block_size,
selector,
)); ));
let queue_clone = queue.clone(); let queue_clone = queue.clone();
// Background task to handle scheduling requests // Background task: receive requests and periodically recheck pending
tokio::spawn(async move { tokio::spawn(async move {
let mut request_rx = request_rx; let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(Duration::from_secs(60)); let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
tracing::trace!("scheduler background task started"); tracing::trace!("scheduler background task started");
loop { loop {
// Use select! to wait on: new request, ready_notify, periodic recheck, or cancellation
tokio::select! { tokio::select! {
_ = scheduler_cancel_token.cancelled() => { _ = scheduler_cancel_token.cancelled() => {
tracing::trace!("scheduler background task shutting down"); tracing::trace!("scheduler background task shutting down");
...@@ -186,74 +177,10 @@ impl KvScheduler { ...@@ -186,74 +177,10 @@ impl KvScheduler {
tracing::trace!("received request to be scheduled"); tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await; queue_clone.enqueue(request).await;
} }
_ = ready_notify.notified() => {
// Woken by update() after prefill_complete/free - just continue to drain ready queue
}
_ = recheck_interval.tick() => { _ = recheck_interval.tick() => {
// Periodic recheck to prevent requests stuck in pending
queue_clone.update().await; queue_clone.update().await;
} }
} }
// Drain ALL ready requests (each iteration uses fresh slot state)
while let Some(mut request) = queue_clone.try_dequeue().await {
let (decode_blocks, prefill_tokens) = slots_clone
.potential_blocks_and_tokens(
request.token_seq.clone(),
request.isl_tokens,
request.overlaps.clone(),
)
.await;
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
// Read the current workers configuration from watch receiver
let workers: HashMap<WorkerId, ModelRuntimeConfig> =
scheduler_rx.borrow().clone();
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => {
let response = SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
};
request.respond(response);
// Skip state update if not requested
if !request.update_states {
continue;
}
let Some(request_id) = request.maybe_request_id else {
tracing::error!(
"No request_id provided to add_request to the slot tracker"
);
continue;
};
if let Err(e) = slots_clone
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
expected_output_tokens: None,
worker: selection.worker,
lora_name: request.lora_name.clone(),
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
Err(KvSchedulerError::NoEndpoints) => {
tracing::warn!("no endpoints available, dropping request");
}
Err(e) => {
tracing::error!("error scheduling request: {:?}", e);
}
}
}
} }
tracing::trace!("background endpoint subscriber shutting down"); tracing::trace!("background endpoint subscriber shutting down");
...@@ -306,7 +233,7 @@ impl KvScheduler { ...@@ -306,7 +233,7 @@ impl KvScheduler {
let response = resp_rx let response = resp_rx
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
let total_elapsed = start.elapsed(); let total_elapsed = start.elapsed();
...@@ -430,25 +357,14 @@ fn softmax_sample( ...@@ -430,25 +357,14 @@ fn softmax_sample(
// All values are the same, uniform probability // All values are the same, uniform probability
vec![1.0 / keys.len() as f64; keys.len()] vec![1.0 / keys.len() as f64; keys.len()]
} else { } else {
// Normalize values // Fused normalize → negate → scale → exp, then normalize probabilities
let normalized: Vec<_> = values let range = max_val - min_val;
.iter() let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
.map(|&v| {
// Lower is better, so negate
// Note we don't need to do actual min-max norm here, just off by an offset
let norm = v / (max_val - min_val);
-norm
})
.collect();
// Apply temperature and softmax
let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect(); let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum: f64 = probs.iter().sum();
let sum_exp: f64 = exp_values.iter().sum(); probs.iter_mut().for_each(|p| *p /= sum);
exp_values.iter().map(|&v| v / sum_exp).collect() probs
}; };
// Sample from the probability distribution // Sample from the probability distribution
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment