Unverified Commit e3d00b89 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
...@@ -44,6 +44,7 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) { ...@@ -44,6 +44,7 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
} }
mod kv_indexer; mod kv_indexer;
mod local; mod local;
mod lower_tier;
mod metrics; mod metrics;
mod thread_pool; mod thread_pool;
mod traits; mod traits;
...@@ -62,6 +63,7 @@ mod tests; ...@@ -62,6 +63,7 @@ mod tests;
pub use branch_sharded::*; pub use branch_sharded::*;
pub use kv_indexer::*; pub use kv_indexer::*;
pub use local::*; pub use local::*;
pub use lower_tier::*;
pub use metrics::*; pub use metrics::*;
pub use thread_pool::*; pub use thread_pool::*;
pub use traits::*; pub use traits::*;
......
...@@ -23,7 +23,7 @@ use std::{ ...@@ -23,7 +23,7 @@ use std::{
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use super::{EventWarningKind, PreBoundEventCounters}; use super::{EventWarningKind, MatchDetails, PreBoundEventCounters};
use crate::active_set::reconcile_active_workers; use crate::active_set::reconcile_active_workers;
use crate::protocols::*; use crate::protocols::*;
...@@ -162,12 +162,20 @@ impl RadixTree { ...@@ -162,12 +162,20 @@ impl RadixTree {
/// ///
/// ### Returns /// ### Returns
/// ///
/// An `OverlapScores` representing the match scores. /// A `MatchDetails` representing overlap scores plus continuation state.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores { pub fn find_match_details(
let mut scores = OverlapScores::new(); &self,
sequence: Vec<LocalBlockHash>,
early_exit: bool,
) -> MatchDetails {
let mut details = MatchDetails::new();
let MatchDetails {
overlap_scores: scores,
last_matched_hashes,
} = &mut details;
if sequence.is_empty() { if sequence.is_empty() {
return scores; return details;
} }
let now = Instant::now(); let now = Instant::now();
...@@ -184,7 +192,7 @@ impl RadixTree { ...@@ -184,7 +192,7 @@ impl RadixTree {
}; };
let Some(first_child) = first_child else { let Some(first_child) = first_child else {
return scores; return details;
}; };
// Initialize active worker set from first child. // Initialize active worker set from first child.
...@@ -208,12 +216,18 @@ impl RadixTree { ...@@ -208,12 +216,18 @@ impl RadixTree {
} }
if active.is_empty() { if active.is_empty() {
return scores; return details;
} }
let mut current_hash = first_child
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
if early_exit && active_count == 1 { if early_exit && active_count == 1 {
for worker in &active { for worker in &active {
scores.scores.insert(*worker, 1); scores.scores.insert(*worker, 1);
last_matched_hashes.insert(*worker, current_hash);
} }
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
let tree_size = self let tree_size = self
...@@ -223,7 +237,7 @@ impl RadixTree { ...@@ -223,7 +237,7 @@ impl RadixTree {
.len(); .len();
scores.tree_sizes.insert(*worker, tree_size); scores.tree_sizes.insert(*worker, tree_size);
} }
return scores; return details;
} }
let mut current = first_child; let mut current = first_child;
...@@ -256,6 +270,7 @@ impl RadixTree { ...@@ -256,6 +270,7 @@ impl RadixTree {
if child_count != active_count { if child_count != active_count {
reconcile_active_workers(&mut active, &borrow.workers, |worker| { reconcile_active_workers(&mut active, &borrow.workers, |worker| {
scores.scores.insert(worker, matched_depth); scores.scores.insert(worker, matched_depth);
last_matched_hashes.insert(worker, current_hash);
}); });
active_count = active.len(); active_count = active.len();
} }
...@@ -281,9 +296,17 @@ impl RadixTree { ...@@ -281,9 +296,17 @@ impl RadixTree {
if early_exit && active_count == 1 { if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32; matched_depth = (idx + 1) as u32;
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
break; break;
} }
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
current = block; current = block;
matched_depth = (idx + 1) as u32; matched_depth = (idx + 1) as u32;
} }
...@@ -291,6 +314,7 @@ impl RadixTree { ...@@ -291,6 +314,7 @@ impl RadixTree {
// Record scores for workers that survived through the deepest matched level. // Record scores for workers that survived through the deepest matched level.
for worker in &active { for worker in &active {
scores.scores.insert(*worker, matched_depth); scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, current_hash);
} }
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores); tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
...@@ -305,7 +329,12 @@ impl RadixTree { ...@@ -305,7 +329,12 @@ impl RadixTree {
scores.tree_sizes.insert(*worker, tree_size); scores.tree_sizes.insert(*worker, tree_size);
} }
scores details
}
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
self.find_match_details(sequence, early_exit).overlap_scores
} }
/// Apply a [`RouterEvent`] to the radix tree. /// Apply a [`RouterEvent`] to the radix tree.
......
...@@ -9,6 +9,7 @@ use tokio::sync::oneshot; ...@@ -9,6 +9,7 @@ use tokio::sync::oneshot;
use crate::protocols::*; use crate::protocols::*;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
/// Trait for types that may represent an error response. /// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure. /// Used for RPC-style responses that can indicate success or failure.
...@@ -250,6 +251,21 @@ impl dynamo_runtime::protocols::maybe_error::MaybeError for IndexerRecordRouting ...@@ -250,6 +251,21 @@ impl dynamo_runtime::protocols::maybe_error::MaybeError for IndexerRecordRouting
} }
} }
/// Rich non-wire query result for router-local device tier lookups.
#[derive(Debug, Clone, Default)]
pub struct MatchDetails {
/// Existing overlap scores used by scheduling.
pub overlap_scores: OverlapScores,
/// Last matched device sequence hash per worker, used to seed lower-tier queries.
pub last_matched_hashes: FxHashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
}
impl MatchDetails {
pub fn new() -> Self {
Self::default()
}
}
/// A request to find matches in the Radix Tree. /// A request to find matches in the Radix Tree.
pub struct MatchRequest { pub struct MatchRequest {
/// A vector of `LocalBlockHash` representing the sequence to match. /// A vector of `LocalBlockHash` representing the sequence to match.
...@@ -279,6 +295,30 @@ impl MatchRequest { ...@@ -279,6 +295,30 @@ impl MatchRequest {
} }
} }
/// A request to find matches while also returning continuation metadata.
pub struct MatchDetailsRequest {
/// A vector of `LocalBlockHash` representing the sequence to match.
pub sequence: Vec<LocalBlockHash>,
/// A boolean indicating whether to exit early if a single match is found.
pub early_exit: bool,
/// A channel sender to send the `MatchDetails` response.
pub resp: oneshot::Sender<MatchDetails>,
}
impl MatchDetailsRequest {
pub(super) fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: oneshot::Sender<MatchDetails>,
) -> Self {
Self {
sequence,
early_exit,
resp,
}
}
}
/// A request to dump the tree as events /// A request to dump the tree as events
pub struct DumpRequest { pub struct DumpRequest {
/// Channel to send the dumped events /// Channel to send the dumped events
......
...@@ -51,7 +51,8 @@ pub use config::{ ...@@ -51,7 +51,8 @@ pub use config::{
SharedCacheType, SharedCacheType,
}; };
pub use indexer::{ pub use indexer::{
BranchShardedIndexer, MaybeError, SharedKvCache, SyncIndexer, ThreadPoolIndexer, BranchShardedIndexer, LowerTierContinuation, LowerTierIndexer, MaybeError, SharedKvCache,
SyncIndexer, ThreadPoolIndexer,
}; };
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
......
...@@ -355,9 +355,12 @@ pub struct WorkerSelectionResult { ...@@ -355,9 +355,12 @@ pub struct WorkerSelectionResult {
/// The total number of blocks required to prefill the request /// The total number of blocks required to prefill the request
pub required_blocks: u64, pub required_blocks: u64,
/// The number of blocks that the selected worker may already have cached. /// Approximate effective cache hit on the selected worker in fractional blocks.
/// This is not a guarantee, but an estimate. /// Use `.round() as u32` for a block-count approximation.
pub overlap_blocks: u32, pub effective_overlap_blocks: f64,
/// Approximate cached-token count derived from the weighted cache hit.
pub cached_tokens: usize,
} }
/// Active load metrics for a worker, used for busy detection. /// Active load metrics for a worker, used for busy detection.
......
...@@ -35,6 +35,14 @@ pub fn min_initial_workers_from_env() -> anyhow::Result<usize> { ...@@ -35,6 +35,14 @@ pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
} }
} }
const fn default_host_cache_hit_weight() -> f64 {
0.75
}
const fn default_disk_cache_hit_weight() -> f64 {
0.25
}
/// Type of external shared KV cache to query during routing. /// Type of external shared KV cache to query during routing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
...@@ -170,6 +178,14 @@ pub struct KvRouterConfig { ...@@ -170,6 +178,14 @@ pub struct KvRouterConfig {
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub overlap_score_weight: f64, pub overlap_score_weight: f64,
#[serde(default = "default_host_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub host_cache_hit_weight: f64,
#[serde(default = "default_disk_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub disk_cache_hit_weight: f64,
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub router_temperature: f64, pub router_temperature: f64,
...@@ -269,6 +285,8 @@ impl Default for KvRouterConfig { ...@@ -269,6 +285,8 @@ impl Default for KvRouterConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
host_cache_hit_weight: default_host_cache_hit_weight(),
disk_cache_hit_weight: default_disk_cache_hit_weight(),
router_temperature: 0.0, router_temperature: 0.0,
use_kv_events: true, use_kv_events: true,
durable_kv_events: false, // default to NATS Core (local indexer mode) durable_kv_events: false, // default to NATS Core (local indexer mode)
......
...@@ -14,8 +14,10 @@ use super::policy::{RouterSchedulingPolicy, SchedulingPolicy}; ...@@ -14,8 +14,10 @@ use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator; use super::prefill_load::PrefillLoadEstimator;
use super::queue::SchedulerQueue; use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector}; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; use super::types::{
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank}; KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse, TierOverlapBlocks,
};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest, ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
}; };
...@@ -42,6 +44,18 @@ where ...@@ -42,6 +44,18 @@ where
S: SchedulingPolicy + 'static, S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static, Sel: WorkerSelector<C> + Send + Sync + 'static,
{ {
fn worker_dp_range(workers: &HashMap<WorkerId, C>) -> HashMap<WorkerId, (u32, u32)> {
workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect()
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>, slots: Arc<ActiveSequencesMultiWorker<P>>,
...@@ -84,15 +98,7 @@ where ...@@ -84,15 +98,7 @@ where
continue; continue;
} }
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers let dp_range = Self::worker_dp_range(&current_workers);
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
slots_monitor.update_workers(&dp_range); slots_monitor.update_workers(&dp_range);
last_workers = current_workers; last_workers = current_workers;
} }
...@@ -168,7 +174,10 @@ where ...@@ -168,7 +174,10 @@ where
maybe_request_id: Option<String>, maybe_request_id: Option<String>,
isl_tokens: usize, isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores, tier_overlap_blocks: TierOverlapBlocks,
effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
tree_sizes: HashMap<WorkerWithDpRank, usize>,
router_config_override: Option<&super::config::RouterConfigOverride>, router_config_override: Option<&super::config::RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
...@@ -186,7 +195,10 @@ where ...@@ -186,7 +195,10 @@ where
maybe_request_id, maybe_request_id,
token_seq, token_seq,
isl_tokens, isl_tokens,
overlaps, tier_overlap_blocks,
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes,
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens, track_prefill_tokens,
...@@ -258,7 +270,7 @@ where ...@@ -258,7 +270,7 @@ where
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
overlaps: OverlapScores, effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
let decay_now = Instant::now(); let decay_now = Instant::now();
...@@ -267,7 +279,7 @@ where ...@@ -267,7 +279,7 @@ where
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
token_seq.as_deref(), token_seq.as_deref(),
isl_tokens, isl_tokens,
overlaps, effective_cached_tokens,
track_prefill_tokens, track_prefill_tokens,
decay_now, decay_now,
); );
...@@ -306,7 +318,7 @@ mod tests { ...@@ -306,7 +318,7 @@ mod tests {
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use super::*; use super::*;
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores}; use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use crate::scheduling::PrefillLoadEstimator; use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy; use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector; use crate::scheduling::selector::DefaultWorkerSelector;
...@@ -423,7 +435,10 @@ mod tests { ...@@ -423,7 +435,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
Some("adapter-a".to_string()), Some("adapter-a".to_string()),
...@@ -462,7 +477,10 @@ mod tests { ...@@ -462,7 +477,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Some(&crate::config::RouterConfigOverride { Some(&crate::config::RouterConfigOverride {
track_prefill_tokens: Some(false), track_prefill_tokens: Some(false),
..Default::default() ..Default::default()
...@@ -489,6 +507,52 @@ mod tests { ...@@ -489,6 +507,52 @@ mod tests {
cancel_token.cancel(); cancel_token.cancel();
} }
#[tokio::test]
async fn test_schedule_uses_weighted_cached_tokens_for_active_tracking() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let worker = WorkerWithDpRank::new(0, 0);
let response = scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
TierOverlapBlocks::default(),
HashMap::from([(worker, 0.75)]),
HashMap::from([(worker, 48)]),
HashMap::new(),
None,
true,
None,
0.0,
None,
None,
None,
None,
)
.await
.unwrap();
assert_eq!(response.best_worker, worker);
assert_eq!(response.cached_tokens, 48);
assert_eq!(response.effective_overlap_blocks, 0.75);
assert_eq!(
slots.active_tokens(Instant::now()).get(&worker).copied(),
Some(16),
"weighted cached tokens should reduce tracked prefill load",
);
cancel_token.cancel();
}
#[tokio::test] #[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() { async fn test_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new(); let mut workers = HashMap::new();
...@@ -507,7 +571,10 @@ mod tests { ...@@ -507,7 +571,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -528,7 +595,10 @@ mod tests { ...@@ -528,7 +595,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -570,7 +640,10 @@ mod tests { ...@@ -570,7 +640,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -591,7 +664,10 @@ mod tests { ...@@ -591,7 +664,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -647,7 +723,10 @@ mod tests { ...@@ -647,7 +723,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -668,7 +747,10 @@ mod tests { ...@@ -668,7 +747,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -723,7 +805,10 @@ mod tests { ...@@ -723,7 +805,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -744,7 +829,10 @@ mod tests { ...@@ -744,7 +829,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -797,7 +885,10 @@ mod tests { ...@@ -797,7 +885,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
Some("adapter-a".to_string()), Some("adapter-a".to_string()),
...@@ -839,14 +930,10 @@ mod tests { ...@@ -839,14 +930,10 @@ mod tests {
); );
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let token_seq = vec![11, 22, 33, 44]; let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default(); let cached_tokens = HashMap::new();
let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens( let (decode_blocks, prefill_tokens) =
Some(&token_seq), slots.potential_blocks_and_tokens(Some(&token_seq), 128, cached_tokens.clone());
128,
overlaps.clone(),
Instant::now(),
);
let mut expected: Vec<_> = decode_blocks let mut expected: Vec<_> = decode_blocks
.keys() .keys()
.map(|worker| PotentialLoad { .map(|worker| PotentialLoad {
...@@ -858,7 +945,7 @@ mod tests { ...@@ -858,7 +945,7 @@ mod tests {
.collect(); .collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank)); expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps, true); let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, cached_tokens, true);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank)); actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len()); assert_eq!(actual.len(), expected.len());
...@@ -899,7 +986,10 @@ mod tests { ...@@ -899,7 +986,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
100, 100,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -914,7 +1004,7 @@ mod tests { ...@@ -914,7 +1004,7 @@ mod tests {
tokio::time::advance(Duration::from_secs(6)).await; tokio::time::advance(Duration::from_secs(6)).await;
let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true); let loads = scheduler.get_potential_loads(None, 0, HashMap::new(), true);
assert_eq!(loads.len(), 1); assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 40); assert_eq!(loads[0].potential_prefill_tokens, 40);
...@@ -927,7 +1017,7 @@ mod tests { ...@@ -927,7 +1017,7 @@ mod tests {
make_scheduler(HashMap::new(), None, false, None); make_scheduler(HashMap::new(), None, false, None);
scheduler.register_workers(&HashSet::from([42])); scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true); let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), true);
assert_eq!(loads.len(), 1); assert_eq!(loads.len(), 1);
assert_eq!(loads[0].worker_id, 42); assert_eq!(loads[0].worker_id, 42);
...@@ -944,7 +1034,7 @@ mod tests { ...@@ -944,7 +1034,7 @@ mod tests {
assert_eq!( assert_eq!(
scheduler scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true) .get_potential_loads(None, 64, HashMap::new(), true,)
.len(), .len(),
1 1
); );
...@@ -963,7 +1053,7 @@ mod tests { ...@@ -963,7 +1053,7 @@ mod tests {
tokio::time::timeout(Duration::from_secs(1), async { tokio::time::timeout(Duration::from_secs(1), async {
loop { loop {
if scheduler if scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true) .get_potential_loads(None, 64, HashMap::new(), true)
.len() .len()
== 3 == 3
{ {
...@@ -995,7 +1085,10 @@ mod tests { ...@@ -995,7 +1085,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![11, 22]), Some(vec![11, 22]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -1008,7 +1101,7 @@ mod tests { ...@@ -1008,7 +1101,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), false); let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), false);
assert_eq!(loads.len(), 1); assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 64); assert_eq!(loads[0].potential_prefill_tokens, 64);
......
...@@ -66,16 +66,34 @@ impl SchedulingPolicy for LcfsPolicy { ...@@ -66,16 +66,34 @@ impl SchedulingPolicy for LcfsPolicy {
/// Optimizes for average TTFT — minimizes total weighted completion time /// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before /// (Smith 1956). Short or high-priority requests are scheduled before
/// long low-priority ones, reducing mean latency across the batch. /// long low-priority ones, reducing mean latency across the batch.
pub struct WsptPolicy { pub struct WsptPolicy;
pub block_size: usize,
}
impl SchedulingPolicy for WsptPolicy { impl SchedulingPolicy for WsptPolicy {
type Key = OrderedFloat<f64>; type Key = OrderedFloat<f64>;
fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key { fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
let weight = 1.0 + request.priority_jump.max(0.0); let weight = 1.0 + request.priority_jump.max(0.0);
let cached_tokens = request.overlap_blocks() as usize * self.block_size; let allowed_ids = request.allowed_worker_ids.as_ref();
let cached_tokens = request.pinned_worker.map_or_else(
|| {
request
.effective_cached_tokens
.iter()
.filter(|(worker, _)| {
allowed_ids.is_none_or(|ids| ids.contains(&worker.worker_id))
})
.map(|(_, tokens)| *tokens)
.max()
.unwrap_or(0)
},
|worker| {
request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0)
},
);
let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1); let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
OrderedFloat(weight / new_tokens as f64) OrderedFloat(weight / new_tokens as f64)
} }
...@@ -91,11 +109,11 @@ pub enum RouterSchedulingPolicy { ...@@ -91,11 +109,11 @@ pub enum RouterSchedulingPolicy {
} }
impl RouterSchedulingPolicy { impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self { pub fn new(kind: RouterQueuePolicy) -> Self {
match kind { match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy), RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy), RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }), RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy),
} }
} }
} }
...@@ -124,11 +142,24 @@ mod tests { ...@@ -124,11 +142,24 @@ mod tests {
priority_jump: f64, priority_jump: f64,
overlaps: OverlapScores, overlaps: OverlapScores,
) -> SchedulingRequest { ) -> SchedulingRequest {
let effective_overlap_blocks = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as f64))
.collect();
let effective_cached_tokens = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as usize * 16))
.collect();
SchedulingRequest { SchedulingRequest {
maybe_request_id: None, maybe_request_id: None,
token_seq: None, token_seq: None,
isl_tokens, isl_tokens,
overlaps, tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes: std::collections::HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
...@@ -224,10 +255,10 @@ mod tests { ...@@ -224,10 +255,10 @@ mod tests {
let early = Duration::from_secs(1); let early = Duration::from_secs(1);
let late = Duration::from_secs(10); let late = Duration::from_secs(10);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16); let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs);
assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req)); assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16); let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs);
assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req)); assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
} }
...@@ -235,7 +266,7 @@ mod tests { ...@@ -235,7 +266,7 @@ mod tests {
#[test] #[test]
fn wspt_shorter_request_scheduled_first() { fn wspt_shorter_request_scheduled_first() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let short = request_with(100, 0.0, OverlapScores::default()); let short = request_with(100, 0.0, OverlapScores::default());
let long = request_with(1000, 0.0, OverlapScores::default()); let long = request_with(1000, 0.0, OverlapScores::default());
let t = Duration::ZERO; let t = Duration::ZERO;
...@@ -247,7 +278,7 @@ mod tests { ...@@ -247,7 +278,7 @@ mod tests {
#[test] #[test]
fn wspt_overlap_reduces_effective_cost() { fn wspt_overlap_reduces_effective_cost() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
// Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens). // Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens).
let no_cache = request_with(1024, 0.0, OverlapScores::default()); let no_cache = request_with(1024, 0.0, OverlapScores::default());
let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)])); let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
...@@ -262,7 +293,7 @@ mod tests { ...@@ -262,7 +293,7 @@ mod tests {
#[test] #[test]
fn wspt_priority_promotes() { fn wspt_priority_promotes() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let normal = request_with(512, 0.0, OverlapScores::default()); let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 5.0, OverlapScores::default()); let boosted = request_with(512, 5.0, OverlapScores::default());
let t = Duration::ZERO; let t = Duration::ZERO;
...@@ -274,7 +305,7 @@ mod tests { ...@@ -274,7 +305,7 @@ mod tests {
#[test] #[test]
fn wspt_uses_max_overlap() { fn wspt_uses_max_overlap() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
// 4 workers with overlaps [10, 20, 50, 60]. max = 60. // 4 workers with overlaps [10, 20, 50, 60]. max = 60.
// new_tokens = 1024 - 60*16 = 64 // new_tokens = 1024 - 60*16 = 64
let req = request_with( let req = request_with(
...@@ -289,7 +320,7 @@ mod tests { ...@@ -289,7 +320,7 @@ mod tests {
#[test] #[test]
fn wspt_uses_pinned_worker_overlap_when_present() { fn wspt_uses_pinned_worker_overlap_when_present() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)])); let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0)); req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
...@@ -300,7 +331,7 @@ mod tests { ...@@ -300,7 +331,7 @@ mod tests {
#[test] #[test]
fn wspt_missing_pinned_overlap_uses_zero() { fn wspt_missing_pinned_overlap_uses_zero() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)])); let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0)); req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
...@@ -311,7 +342,7 @@ mod tests { ...@@ -311,7 +342,7 @@ mod tests {
#[test] #[test]
fn wspt_no_overlap_falls_back_to_isl() { fn wspt_no_overlap_falls_back_to_isl() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let req = request_with(512, 0.0, OverlapScores::default()); let req = request_with(512, 0.0, OverlapScores::default());
let key = policy.enqueue_key(Duration::ZERO, &req); let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 512.0); let expected = OrderedFloat(1.0 / 512.0);
...@@ -320,7 +351,7 @@ mod tests { ...@@ -320,7 +351,7 @@ mod tests {
#[test] #[test]
fn wspt_full_overlap_clamps_to_one() { fn wspt_full_overlap_clamps_to_one() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
// 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1) // 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1)
let req = request_with(512, 0.0, overlaps_from(&[(0, 64)])); let req = request_with(512, 0.0, overlaps_from(&[(0, 64)]));
let key = policy.enqueue_key(Duration::ZERO, &req); let key = policy.enqueue_key(Duration::ZERO, &req);
......
...@@ -239,7 +239,7 @@ impl< ...@@ -239,7 +239,7 @@ impl<
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(), request.token_seq.as_deref(),
request.isl_tokens, request.isl_tokens,
request.overlaps.clone(), request.effective_cached_tokens.clone(),
request.track_prefill_tokens, request.track_prefill_tokens,
decay_now, decay_now,
); );
...@@ -263,7 +263,8 @@ impl< ...@@ -263,7 +263,8 @@ impl<
request.respond(Ok(SchedulingResponse { request.respond(Ok(SchedulingResponse {
best_worker: selection.worker, best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks, effective_overlap_blocks: selection.effective_overlap_blocks,
cached_tokens: selection.cached_tokens,
})); }));
if !request.update_states { if !request.update_states {
...@@ -277,7 +278,7 @@ impl< ...@@ -277,7 +278,7 @@ impl<
let prefill_load_hint = self.prefill_load_hint_for( let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens, request.isl_tokens,
selection.overlap_blocks, selection.cached_tokens,
request.track_prefill_tokens, request.track_prefill_tokens,
); );
...@@ -291,7 +292,7 @@ impl< ...@@ -291,7 +292,7 @@ impl<
worker: selection.worker, worker: selection.worker,
lora_name: request.lora_name.clone(), lora_name: request.lora_name.clone(),
}, },
decay_now, Instant::now(),
) { ) {
tracing::warn!("Failed to add request {request_id}: {e}"); tracing::warn!("Failed to add request {request_id}: {e}");
} }
...@@ -300,14 +301,14 @@ impl< ...@@ -300,14 +301,14 @@ impl<
fn prefill_load_hint_for( fn prefill_load_hint_for(
&self, &self,
isl_tokens: usize, isl_tokens: usize,
overlap_blocks: u32, cached_tokens: usize,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> { ) -> Option<PrefillLoadHint> {
if !track_prefill_tokens { if !track_prefill_tokens {
return None; return None;
} }
let prefix = (overlap_blocks as usize) * (self.block_size as usize); let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix); let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 { if effective_isl == 0 {
return None; return None;
...@@ -408,7 +409,7 @@ mod tests { ...@@ -408,7 +409,7 @@ mod tests {
use tokio::sync::{Barrier, watch}; use tokio::sync::{Barrier, watch};
use super::*; use super::*;
use crate::protocols::{OverlapScores, WorkerSelectionResult, WorkerWithDpRank}; use crate::protocols::{WorkerSelectionResult, WorkerWithDpRank};
use crate::scheduling::types::KvSchedulerError; use crate::scheduling::types::KvSchedulerError;
use crate::sequences::ActiveSequencesMultiWorker; use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
...@@ -499,7 +500,16 @@ mod tests { ...@@ -499,7 +500,16 @@ mod tests {
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker, worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64, required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: request.overlaps.scores.get(&worker).copied().unwrap_or(0), effective_overlap_blocks: request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0),
cached_tokens: request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0),
}) })
} }
} }
...@@ -628,7 +638,10 @@ mod tests { ...@@ -628,7 +638,10 @@ mod tests {
maybe_request_id: Some(request_id.to_string()), maybe_request_id: Some(request_id.to_string()),
token_seq: None, token_seq: None,
isl_tokens, isl_tokens,
overlaps: OverlapScores::default(), tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
...@@ -1020,7 +1033,10 @@ mod tests { ...@@ -1020,7 +1033,10 @@ mod tests {
maybe_request_id: Some("filter-0".to_string()), maybe_request_id: Some("filter-0".to_string()),
token_seq: None, token_seq: None,
isl_tokens: isl, isl_tokens: isl,
overlaps: OverlapScores::default(), tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
......
...@@ -37,59 +37,49 @@ fn softmax_sample_with_sample( ...@@ -37,59 +37,49 @@ fn softmax_sample_with_sample(
temperature: f64, temperature: f64,
sample: f64, sample: f64,
) -> (WorkerWithDpRank, f64) { ) -> (WorkerWithDpRank, f64) {
if logits.is_empty() { assert!(!logits.is_empty(), "Empty logits for softmax sampling");
panic!("Empty logits for softmax sampling");
}
// Guard: at zero temperature, return a minimum-logit worker directly.
if temperature == 0.0 { if temperature == 0.0 {
let mut logit_iter = logits.iter(); let (worker, logit) = logits
let (first_key, first_logit) = logit_iter.next().unwrap(); .iter()
.min_by(|a, b| a.1.total_cmp(b.1))
let mut min_logit = first_logit; .expect("logits non-empty");
let mut min_key = first_key; return (*worker, *logit);
for (key, logit) in logit_iter {
if logit < min_logit {
min_logit = logit;
min_key = key;
}
}
return (*min_key, *min_logit);
} }
let entries: Vec<_> = logits let entries: Vec<(WorkerWithDpRank, f64)> = logits.iter().map(|(w, l)| (*w, *l)).collect();
.iter()
.map(|(worker, logit)| (*worker, *logit))
.collect();
let values: Vec<_> = entries.iter().map(|(_, logit)| *logit).collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b)); let (min_val, max_val) = entries
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); .iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), (_, v)| {
(lo.min(*v), hi.max(*v))
});
let probabilities = if min_val == max_val { let mut probs = if min_val == max_val {
vec![1.0 / entries.len() as f64; entries.len()] vec![1.0 / entries.len() as f64; entries.len()]
} else { } else {
// Fused normalize -> negate -> scale -> exp, then normalize probabilities // Negate logits and rescale to [−1/temperature, 0] for numerical stability
let range = max_val - min_val; // before softmax. Subtracting the max (which maps to min_val) keeps exp() inputs ≤ 0.
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect(); let scale = -1.0 / ((max_val - min_val) * temperature);
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let max_scaled = min_val * scale;
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect(); entries
let sum: f64 = probs.iter().sum(); .iter()
probs.iter_mut().for_each(|p| *p /= sum); .map(|(_, v)| (v * scale - max_scaled).exp())
probs .collect::<Vec<f64>>()
}; };
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
let mut cumsum = 0.0; let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() { for (i, &prob) in probs.iter().enumerate() {
cumsum += prob; cumsum += prob;
if sample <= cumsum { if sample <= cumsum {
return entries[i]; return entries[i];
} }
} }
// Fallback to last key (shouldn't normally reach here) *entries.last().unwrap()
entries[entries.len() - 1]
} }
/// Default implementation matching the Python _cost_function. /// Default implementation matching the Python _cost_function.
...@@ -99,12 +89,6 @@ pub struct DefaultWorkerSelector { ...@@ -99,12 +89,6 @@ pub struct DefaultWorkerSelector {
pub worker_type: &'static str, pub worker_type: &'static str,
} }
#[derive(Debug, Clone, Copy)]
struct WorkerScore {
overlap_blocks: u32,
logit: f64,
}
impl DefaultWorkerSelector { impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self { pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self {
Self { Self {
...@@ -113,7 +97,7 @@ impl DefaultWorkerSelector { ...@@ -113,7 +97,7 @@ impl DefaultWorkerSelector {
} }
} }
fn worker_score( fn worker_logit(
&self, &self,
request: &SchedulingRequest, request: &SchedulingRequest,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
...@@ -121,9 +105,16 @@ impl DefaultWorkerSelector { ...@@ -121,9 +105,16 @@ impl DefaultWorkerSelector {
overlap_weight: f64, overlap_weight: f64,
shared_cache_multiplier: f64, shared_cache_multiplier: f64,
formula_name: &'static str, formula_name: &'static str,
) -> WorkerScore { ) -> f64 {
let isl = request.isl_tokens; let isl = request.isl_tokens;
let overlap_blocks = request.overlaps.scores.get(&worker).copied().unwrap_or(0); let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
// `shared_cache_hits::hits_beyond` expects an integer block count, so
// round the weighted overlap for this comparison only.
let device_overlap_blocks = effective_overlap_blocks.round().max(0.0) as u32;
let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 }; let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 };
let prefill_token = request let prefill_token = request
.prefill_tokens .prefill_tokens
...@@ -134,7 +125,7 @@ impl DefaultWorkerSelector { ...@@ -134,7 +125,7 @@ impl DefaultWorkerSelector {
// Adjust prefill tokens by shared cache hits beyond this worker's device prefix. // Adjust prefill tokens by shared cache hits beyond this worker's device prefix.
let (adjusted_prefill_token, shared_beyond) = let (adjusted_prefill_token, shared_beyond) =
if let Some(ref shared_hits) = request.shared_cache_hits { if let Some(ref shared_hits) = request.shared_cache_hits {
let beyond = shared_hits.hits_beyond(overlap_blocks); let beyond = shared_hits.hits_beyond(device_overlap_blocks);
let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64); let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64);
let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize; let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize;
(adjusted, beyond) (adjusted, beyond)
...@@ -153,7 +144,7 @@ impl DefaultWorkerSelector { ...@@ -153,7 +144,7 @@ impl DefaultWorkerSelector {
if shared_beyond > 0 { if shared_beyond > 0 {
tracing::debug!( tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} device blocks, \ "{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective device blocks, \
{shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \ {shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \
= {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \ = {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \ = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \
...@@ -163,7 +154,7 @@ impl DefaultWorkerSelector { ...@@ -163,7 +154,7 @@ impl DefaultWorkerSelector {
); );
} else { } else {
tracing::debug!( tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \ "{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}", = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id, worker.worker_id,
...@@ -171,10 +162,7 @@ impl DefaultWorkerSelector { ...@@ -171,10 +162,7 @@ impl DefaultWorkerSelector {
); );
} }
WorkerScore { logit
overlap_blocks,
logit,
}
} }
} }
...@@ -201,7 +189,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -201,7 +189,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let isl = request.isl_tokens; let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize); let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
let overlap_weight = request let overlap_weight = request
.router_config_override .router_config_override
...@@ -218,7 +205,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -218,7 +205,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
if let Some(worker) = pinned_worker { if let Some(worker) = pinned_worker {
pinned_worker_config(workers, worker)?; pinned_worker_config(workers, worker)?;
let score = self.worker_score( let logit = self.worker_logit(
request, request,
worker, worker,
block_size, block_size,
...@@ -226,11 +213,31 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -226,11 +213,31 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier, shared_cache_multiplier,
"Pinned formula", "Pinned formula",
); );
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0);
tracing::info!(
"Selected pinned worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}",
self.worker_type,
worker.worker_id,
worker.dp_rank,
logit,
effective_overlap_blocks,
);
return Ok(WorkerSelectionResult { return Ok(WorkerSelectionResult {
worker, worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: score.overlap_blocks, effective_overlap_blocks,
cached_tokens,
}); });
} }
...@@ -241,7 +248,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -241,7 +248,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let get_score = |worker: WorkerWithDpRank| -> f64 { let get_score = |worker: WorkerWithDpRank| -> f64 {
self.worker_score( self.worker_logit(
request, request,
worker, worker,
block_size, block_size,
...@@ -249,7 +256,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -249,7 +256,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier, shared_cache_multiplier,
"Formula", "Formula",
) )
.logit
}; };
let worker_iter = workers let worker_iter = workers
...@@ -282,7 +288,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -282,7 +288,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
); );
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers
.iter() .iter()
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w)) .map(|w| (request.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect(); .collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) { if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
...@@ -305,22 +311,58 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -305,22 +311,58 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
softmax_sample(&worker_logits, temperature) softmax_sample(&worker_logits, temperature)
}; };
let best_host_pinned_overlap_blocks = request
.tier_overlap_blocks
.host_pinned
.get(&best_worker)
.copied()
.unwrap_or(0);
let best_disk_overlap_blocks = request
.tier_overlap_blocks
.disk
.get(&best_worker)
.copied()
.unwrap_or(0);
if self.worker_type == "decode" { if self.worker_type == "decode" {
tracing::info!( tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}", "Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, host_pinned blocks: {}, disk blocks: {}",
self.worker_type, self.worker_type,
best_worker.worker_id, best_worker.worker_id,
best_worker.dp_rank, best_worker.dp_rank,
best_logit, best_logit,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
); );
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
return Ok(WorkerSelectionResult { return Ok(WorkerSelectionResult {
worker: best_worker, worker: best_worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0), effective_overlap_blocks,
cached_tokens,
}); });
} }
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0); let best_overlap = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let best_cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
let total_blocks_info = workers let total_blocks_info = workers
.get(&best_worker.worker_id) .get(&best_worker.worker_id)
...@@ -328,20 +370,17 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -328,20 +370,17 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.map(|blocks| format!(", total blocks: {}", blocks)) .map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default(); .unwrap_or_default();
let tree_size = request let tree_size = request.tree_sizes.get(&best_worker).copied().unwrap_or(0);
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
tracing::info!( tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}", "Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}, host_pinned blocks: {}, disk blocks: {}, tree size: {}{}",
self.worker_type, self.worker_type,
best_worker.worker_id, best_worker.worker_id,
best_worker.dp_rank, best_worker.dp_rank,
best_logit, best_logit,
best_overlap, best_overlap,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
tree_size, tree_size,
total_blocks_info total_blocks_info
); );
...@@ -349,7 +388,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -349,7 +388,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker: best_worker, worker: best_worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0), effective_overlap_blocks: best_overlap,
cached_tokens: best_cached_tokens,
}) })
} }
} }
...@@ -479,7 +519,6 @@ mod tests { ...@@ -479,7 +519,6 @@ mod tests {
/// Worker 1 has lower logit (less work), so it wins. /// Worker 1 has lower logit (less work), so it wins.
#[test] #[test]
fn test_shared_cache_hits_scoring() { fn test_shared_cache_hits_scoring() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig; use crate::test_utils::SimpleWorkerConfig;
let block_size = 1u32; let block_size = 1u32;
...@@ -487,8 +526,8 @@ mod tests { ...@@ -487,8 +526,8 @@ mod tests {
let worker0 = WorkerWithDpRank::from_worker_id(0); let worker0 = WorkerWithDpRank::from_worker_id(0);
let worker1 = WorkerWithDpRank::from_worker_id(1); let worker1 = WorkerWithDpRank::from_worker_id(1);
let mut overlaps = OverlapScores::new(); let mut effective_overlap_blocks = HashMap::new();
overlaps.scores.insert(worker0, 2); effective_overlap_blocks.insert(worker0, 2.0);
// worker1 has 0 overlap (not in map) // worker1 has 0 overlap (not in map)
#[allow(clippy::single_range_in_vec_init)] #[allow(clippy::single_range_in_vec_init)]
...@@ -511,7 +550,10 @@ mod tests { ...@@ -511,7 +550,10 @@ mod tests {
maybe_request_id: Some("test".into()), maybe_request_id: Some("test".into()),
token_seq: None, token_seq: None,
isl_tokens: isl, isl_tokens: isl,
overlaps, tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
...@@ -540,15 +582,14 @@ mod tests { ...@@ -540,15 +582,14 @@ mod tests {
/// Without shared cache hits, the scoring should be unchanged. /// Without shared cache hits, the scoring should be unchanged.
#[test] #[test]
fn test_no_shared_cache_unchanged() { fn test_no_shared_cache_unchanged() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig; use crate::test_utils::SimpleWorkerConfig;
let block_size = 16u32; let block_size = 16u32;
let isl = 64usize; let isl = 64usize;
let worker0 = WorkerWithDpRank::from_worker_id(0); let worker0 = WorkerWithDpRank::from_worker_id(0);
let mut overlaps = OverlapScores::new(); let mut effective_overlap_blocks = HashMap::new();
overlaps.scores.insert(worker0, 2); effective_overlap_blocks.insert(worker0, 2.0);
let config = KvRouterConfig::default(); let config = KvRouterConfig::default();
let selector = DefaultWorkerSelector::new(Some(config), "test"); let selector = DefaultWorkerSelector::new(Some(config), "test");
...@@ -560,7 +601,10 @@ mod tests { ...@@ -560,7 +601,10 @@ mod tests {
maybe_request_id: Some("test".into()), maybe_request_id: Some("test".into()),
token_seq: None, token_seq: None,
isl_tokens: isl, isl_tokens: isl,
overlaps, tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
......
...@@ -8,9 +8,13 @@ use rustc_hash::FxHashMap; ...@@ -8,9 +8,13 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::config::RouterConfigOverride; use super::config::RouterConfigOverride;
use crate::protocols::{ use crate::protocols::{DpRank, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank};
DpRank, OverlapScores, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank,
}; #[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TierOverlapBlocks {
pub host_pinned: HashMap<WorkerWithDpRank, usize>,
pub disk: HashMap<WorkerWithDpRank, usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad { pub struct PotentialLoad {
...@@ -38,14 +42,20 @@ pub enum KvSchedulerError { ...@@ -38,14 +42,20 @@ pub enum KvSchedulerError {
#[derive(Debug)] #[derive(Debug)]
pub struct SchedulingResponse { pub struct SchedulingResponse {
pub best_worker: WorkerWithDpRank, pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32, pub effective_overlap_blocks: f64,
pub cached_tokens: usize,
} }
pub struct SchedulingRequest { pub struct SchedulingRequest {
pub maybe_request_id: Option<String>, pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>, pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlaps: OverlapScores, pub tier_overlap_blocks: TierOverlapBlocks,
pub effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
pub effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
/// Per-worker tree size, used only for tie-breaking when multiple workers
/// produce the same logit at temperature=0.
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
pub decode_blocks: FxHashMap<WorkerWithDpRank, usize>, pub decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: FxHashMap<WorkerWithDpRank, usize>, pub prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
pub track_prefill_tokens: bool, pub track_prefill_tokens: bool,
...@@ -85,16 +95,6 @@ impl SchedulingRequest { ...@@ -85,16 +95,6 @@ impl SchedulingRequest {
}) })
} }
/// Scheduling consumers use the exact pinned-worker overlap when present;
/// otherwise they use the best available overlap across eligible workers.
pub fn overlap_blocks(&self) -> u32 {
if let Some(worker) = self.pinned_worker {
return self.overlaps.scores.get(&worker).copied().unwrap_or(0);
}
self.overlaps.scores.values().copied().max().unwrap_or(0)
}
pub fn bypass_capacity_check(&self) -> bool { pub fn bypass_capacity_check(&self) -> bool {
self.pinned_worker.is_none() && self.allowed_worker_ids.is_some() self.pinned_worker.is_none() && self.allowed_worker_ids.is_some()
} }
......
This diff is collapsed.
...@@ -51,6 +51,7 @@ impl PrefillLoadSnapshot { ...@@ -51,6 +51,7 @@ impl PrefillLoadSnapshot {
} }
} }
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize { pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * block_size; let cached_tokens = (overlap as usize) * block_size;
isl.checked_sub(cached_tokens).unwrap_or_else(|| { isl.checked_sub(cached_tokens).unwrap_or_else(|| {
......
...@@ -534,6 +534,7 @@ impl PromptMembershipTrie { ...@@ -534,6 +534,7 @@ impl PromptMembershipTrie {
} }
} }
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn compute_overlap_depths( pub(super) fn compute_overlap_depths(
&self, &self,
query: Option<&[SequenceHash]>, query: Option<&[SequenceHash]>,
......
...@@ -97,6 +97,7 @@ impl PromptRegistry { ...@@ -97,6 +97,7 @@ impl PromptRegistry {
} }
#[expect(clippy::too_many_arguments)] #[expect(clippy::too_many_arguments)]
#[cfg_attr(not(test), allow(dead_code))]
fn project_loads_from_overlap( fn project_loads_from_overlap(
&self, &self,
query_len: usize, query_len: usize,
...@@ -135,6 +136,7 @@ impl PromptRegistry { ...@@ -135,6 +136,7 @@ impl PromptRegistry {
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn potential_blocks_and_tokens_with_prefill_tracking( pub(super) fn potential_blocks_and_tokens_with_prefill_tracking(
&self, &self,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
......
...@@ -109,8 +109,6 @@ pub struct ActiveSequences { ...@@ -109,8 +109,6 @@ pub struct ActiveSequences {
requests: HashMap<RequestId, RequestState>, requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker, prefill: PrefillLoadTracker,
blocks: BlockTracker, blocks: BlockTracker,
#[cfg(test)]
block_size: usize,
last_expiry_check_time: Instant, last_expiry_check_time: Instant,
} }
...@@ -123,8 +121,6 @@ impl ActiveSequences { ...@@ -123,8 +121,6 @@ impl ActiveSequences {
requests: HashMap::new(), requests: HashMap::new(),
prefill: PrefillLoadTracker::default(), prefill: PrefillLoadTracker::default(),
blocks: BlockTracker::default(), blocks: BlockTracker::default(),
#[cfg(test)]
block_size,
last_expiry_check_time: Instant::now(), last_expiry_check_time: Instant::now(),
} }
} }
...@@ -157,14 +153,12 @@ impl ActiveSequences { ...@@ -157,14 +153,12 @@ impl ActiveSequences {
self.blocks.active_blocks() self.blocks.active_blocks()
} }
#[cfg(test)]
pub(super) fn active_tokens(&self, decay_now: Instant) -> usize { pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
self.prefill.snapshot().active_tokens_at(decay_now) self.prefill.snapshot().active_tokens_at(decay_now)
} }
/// Add a new request with optional prompt-token load accounting. /// Add a new request with optional prompt-token load accounting.
/// Returns block membership transitions plus any expired request IDs removed during cleanup. /// Returns block membership transitions plus any expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)]
pub(super) fn add_request_with_prefill_tracking( pub(super) fn add_request_with_prefill_tracking(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
...@@ -299,6 +293,31 @@ impl ActiveSequences { ...@@ -299,6 +293,31 @@ impl ActiveSequences {
membership_delta membership_delta
} }
pub fn new_tokens(&self, isl: usize, cached_tokens: usize) -> usize {
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens}, returning 0"
);
0
})
}
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
cached_tokens: usize,
decay_now: Instant,
) -> (usize, usize) {
self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
cached_tokens,
true,
decay_now,
)
}
/// Add an output block with a random hash and optional fractional decay weight. /// Add an output block with a random hash and optional fractional decay weight.
/// ///
/// This is used during generation to track output blocks as they are created. /// This is used during generation to track output blocks as they are created.
...@@ -330,12 +349,11 @@ impl ActiveSequences { ...@@ -330,12 +349,11 @@ impl ActiveSequences {
acquire.became_present_on_worker.then_some(random_hash) acquire.became_present_on_worker.then_some(random_hash)
} }
#[cfg(test)] pub fn potential_blocks_and_tokens_with_prefill_tracking(
fn potential_blocks_and_tokens_with_prefill_tracking(
&self, &self,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlap: u32, cached_tokens: usize,
track_prefill_tokens: bool, track_prefill_tokens: bool,
decay_now: Instant, decay_now: Instant,
) -> (usize, usize) { ) -> (usize, usize) {
...@@ -346,7 +364,7 @@ impl ActiveSequences { ...@@ -346,7 +364,7 @@ impl ActiveSequences {
}; };
let active_tokens = self.active_tokens(decay_now); let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens { let potential_tokens = if track_prefill_tokens {
added_prefill_tokens(self.block_size, isl, overlap) + active_tokens self.new_tokens(isl, cached_tokens) + active_tokens
} else { } else {
active_tokens active_tokens
}; };
......
This diff is collapsed.
...@@ -230,6 +230,7 @@ impl KvBlockManagerConfigBuilder { ...@@ -230,6 +230,7 @@ impl KvBlockManagerConfigBuilder {
engine_endpoint: String, engine_endpoint: String,
output_endpoint: Option<String>, output_endpoint: Option<String>,
engine_source: crate::block_manager::kv_consolidator::EventSource, engine_source: crate::block_manager::kv_consolidator::EventSource,
mode: crate::block_manager::kv_consolidator::KvEventConsolidationMode,
) -> Self { ) -> Self {
let config = match engine_source { let config = match engine_source {
crate::block_manager::kv_consolidator::EventSource::Vllm => { crate::block_manager::kv_consolidator::EventSource::Vllm => {
...@@ -237,6 +238,7 @@ impl KvBlockManagerConfigBuilder { ...@@ -237,6 +238,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_vllm( crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_vllm(
engine_endpoint, engine_endpoint,
output_ep, output_ep,
mode,
) )
} }
crate::block_manager::kv_consolidator::EventSource::Trtllm => { crate::block_manager::kv_consolidator::EventSource::Trtllm => {
...@@ -248,6 +250,7 @@ impl KvBlockManagerConfigBuilder { ...@@ -248,6 +250,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_trtllm( crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_trtllm(
engine_endpoint, engine_endpoint,
output_ep, output_ep,
mode,
) )
} }
crate::block_manager::kv_consolidator::EventSource::Kvbm => { crate::block_manager::kv_consolidator::EventSource::Kvbm => {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment