Unverified Commit e3d00b89 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
...@@ -44,6 +44,7 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) { ...@@ -44,6 +44,7 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
} }
mod kv_indexer; mod kv_indexer;
mod local; mod local;
mod lower_tier;
mod metrics; mod metrics;
mod thread_pool; mod thread_pool;
mod traits; mod traits;
...@@ -62,6 +63,7 @@ mod tests; ...@@ -62,6 +63,7 @@ mod tests;
pub use branch_sharded::*; pub use branch_sharded::*;
pub use kv_indexer::*; pub use kv_indexer::*;
pub use local::*; pub use local::*;
pub use lower_tier::*;
pub use metrics::*; pub use metrics::*;
pub use thread_pool::*; pub use thread_pool::*;
pub use traits::*; pub use traits::*;
......
...@@ -23,7 +23,7 @@ use std::{ ...@@ -23,7 +23,7 @@ use std::{
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use super::{EventWarningKind, PreBoundEventCounters}; use super::{EventWarningKind, MatchDetails, PreBoundEventCounters};
use crate::active_set::reconcile_active_workers; use crate::active_set::reconcile_active_workers;
use crate::protocols::*; use crate::protocols::*;
...@@ -162,12 +162,20 @@ impl RadixTree { ...@@ -162,12 +162,20 @@ impl RadixTree {
/// ///
/// ### Returns /// ### Returns
/// ///
/// An `OverlapScores` representing the match scores. /// A `MatchDetails` representing overlap scores plus continuation state.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores { pub fn find_match_details(
let mut scores = OverlapScores::new(); &self,
sequence: Vec<LocalBlockHash>,
early_exit: bool,
) -> MatchDetails {
let mut details = MatchDetails::new();
let MatchDetails {
overlap_scores: scores,
last_matched_hashes,
} = &mut details;
if sequence.is_empty() { if sequence.is_empty() {
return scores; return details;
} }
let now = Instant::now(); let now = Instant::now();
...@@ -184,7 +192,7 @@ impl RadixTree { ...@@ -184,7 +192,7 @@ impl RadixTree {
}; };
let Some(first_child) = first_child else { let Some(first_child) = first_child else {
return scores; return details;
}; };
// Initialize active worker set from first child. // Initialize active worker set from first child.
...@@ -208,12 +216,18 @@ impl RadixTree { ...@@ -208,12 +216,18 @@ impl RadixTree {
} }
if active.is_empty() { if active.is_empty() {
return scores; return details;
} }
let mut current_hash = first_child
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
if early_exit && active_count == 1 { if early_exit && active_count == 1 {
for worker in &active { for worker in &active {
scores.scores.insert(*worker, 1); scores.scores.insert(*worker, 1);
last_matched_hashes.insert(*worker, current_hash);
} }
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
let tree_size = self let tree_size = self
...@@ -223,7 +237,7 @@ impl RadixTree { ...@@ -223,7 +237,7 @@ impl RadixTree {
.len(); .len();
scores.tree_sizes.insert(*worker, tree_size); scores.tree_sizes.insert(*worker, tree_size);
} }
return scores; return details;
} }
let mut current = first_child; let mut current = first_child;
...@@ -256,6 +270,7 @@ impl RadixTree { ...@@ -256,6 +270,7 @@ impl RadixTree {
if child_count != active_count { if child_count != active_count {
reconcile_active_workers(&mut active, &borrow.workers, |worker| { reconcile_active_workers(&mut active, &borrow.workers, |worker| {
scores.scores.insert(worker, matched_depth); scores.scores.insert(worker, matched_depth);
last_matched_hashes.insert(worker, current_hash);
}); });
active_count = active.len(); active_count = active.len();
} }
...@@ -281,9 +296,17 @@ impl RadixTree { ...@@ -281,9 +296,17 @@ impl RadixTree {
if early_exit && active_count == 1 { if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32; matched_depth = (idx + 1) as u32;
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
break; break;
} }
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
current = block; current = block;
matched_depth = (idx + 1) as u32; matched_depth = (idx + 1) as u32;
} }
...@@ -291,6 +314,7 @@ impl RadixTree { ...@@ -291,6 +314,7 @@ impl RadixTree {
// Record scores for workers that survived through the deepest matched level. // Record scores for workers that survived through the deepest matched level.
for worker in &active { for worker in &active {
scores.scores.insert(*worker, matched_depth); scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, current_hash);
} }
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores); tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
...@@ -305,7 +329,12 @@ impl RadixTree { ...@@ -305,7 +329,12 @@ impl RadixTree {
scores.tree_sizes.insert(*worker, tree_size); scores.tree_sizes.insert(*worker, tree_size);
} }
scores details
}
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
self.find_match_details(sequence, early_exit).overlap_scores
} }
/// Apply a [`RouterEvent`] to the radix tree. /// Apply a [`RouterEvent`] to the radix tree.
......
...@@ -9,6 +9,7 @@ use tokio::sync::oneshot; ...@@ -9,6 +9,7 @@ use tokio::sync::oneshot;
use crate::protocols::*; use crate::protocols::*;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
/// Trait for types that may represent an error response. /// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure. /// Used for RPC-style responses that can indicate success or failure.
...@@ -250,6 +251,21 @@ impl dynamo_runtime::protocols::maybe_error::MaybeError for IndexerRecordRouting ...@@ -250,6 +251,21 @@ impl dynamo_runtime::protocols::maybe_error::MaybeError for IndexerRecordRouting
} }
} }
/// Rich non-wire query result for router-local device tier lookups.
#[derive(Debug, Clone, Default)]
pub struct MatchDetails {
/// Existing overlap scores used by scheduling.
pub overlap_scores: OverlapScores,
/// Last matched device sequence hash per worker, used to seed lower-tier queries.
pub last_matched_hashes: FxHashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
}
impl MatchDetails {
pub fn new() -> Self {
Self::default()
}
}
/// A request to find matches in the Radix Tree. /// A request to find matches in the Radix Tree.
pub struct MatchRequest { pub struct MatchRequest {
/// A vector of `LocalBlockHash` representing the sequence to match. /// A vector of `LocalBlockHash` representing the sequence to match.
...@@ -279,6 +295,30 @@ impl MatchRequest { ...@@ -279,6 +295,30 @@ impl MatchRequest {
} }
} }
/// Request for a match query whose answer also carries continuation metadata.
pub struct MatchDetailsRequest {
    /// The `LocalBlockHash` sequence to match.
    pub sequence: Vec<LocalBlockHash>,
    /// Whether to exit early if a single match is found.
    pub early_exit: bool,
    /// Oneshot sender used to deliver the `MatchDetails` response.
    pub resp: oneshot::Sender<MatchDetails>,
}

impl MatchDetailsRequest {
    /// Bundles the query inputs together with the response channel.
    pub(super) fn new(
        sequence: Vec<LocalBlockHash>,
        early_exit: bool,
        resp: oneshot::Sender<MatchDetails>,
    ) -> Self {
        MatchDetailsRequest {
            sequence,
            early_exit,
            resp,
        }
    }
}
/// A request to dump the tree as events /// A request to dump the tree as events
pub struct DumpRequest { pub struct DumpRequest {
/// Channel to send the dumped events /// Channel to send the dumped events
......
...@@ -51,7 +51,8 @@ pub use config::{ ...@@ -51,7 +51,8 @@ pub use config::{
SharedCacheType, SharedCacheType,
}; };
pub use indexer::{ pub use indexer::{
BranchShardedIndexer, MaybeError, SharedKvCache, SyncIndexer, ThreadPoolIndexer, BranchShardedIndexer, LowerTierContinuation, LowerTierIndexer, MaybeError, SharedKvCache,
SyncIndexer, ThreadPoolIndexer,
}; };
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
......
...@@ -355,9 +355,12 @@ pub struct WorkerSelectionResult { ...@@ -355,9 +355,12 @@ pub struct WorkerSelectionResult {
/// The total number of blocks required to prefill the request /// The total number of blocks required to prefill the request
pub required_blocks: u64, pub required_blocks: u64,
/// The number of blocks that the selected worker may already have cached. /// Approximate effective cache hit on the selected worker in fractional blocks.
/// This is not a guarantee, but an estimate. /// Use `.round() as u32` for a block-count approximation.
pub overlap_blocks: u32, pub effective_overlap_blocks: f64,
/// Approximate cached-token count derived from the weighted cache hit.
pub cached_tokens: usize,
} }
/// Active load metrics for a worker, used for busy detection. /// Active load metrics for a worker, used for busy detection.
......
...@@ -35,6 +35,14 @@ pub fn min_initial_workers_from_env() -> anyhow::Result<usize> { ...@@ -35,6 +35,14 @@ pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
} }
} }
/// Default value for `host_cache_hit_weight`.
///
/// Referenced by the field's `#[serde(default = ...)]` attribute and by
/// `KvRouterConfig::default()`.
const fn default_host_cache_hit_weight() -> f64 {
    const HOST_CACHE_HIT_WEIGHT: f64 = 0.75;
    HOST_CACHE_HIT_WEIGHT
}
/// Default value for `disk_cache_hit_weight`.
///
/// Referenced by the field's `#[serde(default = ...)]` attribute and by
/// `KvRouterConfig::default()`.
const fn default_disk_cache_hit_weight() -> f64 {
    const DISK_CACHE_HIT_WEIGHT: f64 = 0.25;
    DISK_CACHE_HIT_WEIGHT
}
/// Type of external shared KV cache to query during routing. /// Type of external shared KV cache to query during routing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
...@@ -170,6 +178,14 @@ pub struct KvRouterConfig { ...@@ -170,6 +178,14 @@ pub struct KvRouterConfig {
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub overlap_score_weight: f64, pub overlap_score_weight: f64,
#[serde(default = "default_host_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub host_cache_hit_weight: f64,
#[serde(default = "default_disk_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub disk_cache_hit_weight: f64,
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
pub router_temperature: f64, pub router_temperature: f64,
...@@ -269,6 +285,8 @@ impl Default for KvRouterConfig { ...@@ -269,6 +285,8 @@ impl Default for KvRouterConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
overlap_score_weight: 1.0, overlap_score_weight: 1.0,
host_cache_hit_weight: default_host_cache_hit_weight(),
disk_cache_hit_weight: default_disk_cache_hit_weight(),
router_temperature: 0.0, router_temperature: 0.0,
use_kv_events: true, use_kv_events: true,
durable_kv_events: false, // default to NATS Core (local indexer mode) durable_kv_events: false, // default to NATS Core (local indexer mode)
......
...@@ -14,8 +14,10 @@ use super::policy::{RouterSchedulingPolicy, SchedulingPolicy}; ...@@ -14,8 +14,10 @@ use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator; use super::prefill_load::PrefillLoadEstimator;
use super::queue::SchedulerQueue; use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector}; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; use super::types::{
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank}; KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse, TierOverlapBlocks,
};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest, ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
}; };
...@@ -42,6 +44,18 @@ where ...@@ -42,6 +44,18 @@ where
S: SchedulingPolicy + 'static, S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static, Sel: WorkerSelector<C> + Send + Sync + 'static,
{ {
/// Maps every worker id to its `(data_parallel_start_rank, data_parallel_size)` pair.
///
/// The result is the shape handed to `update_workers` when the worker set changes.
fn worker_dp_range(workers: &HashMap<WorkerId, C>) -> HashMap<WorkerId, (u32, u32)> {
    let mut ranges: HashMap<WorkerId, (u32, u32)> = HashMap::default();
    for (worker_id, config) in workers {
        let start_rank = config.data_parallel_start_rank();
        let dp_size = config.data_parallel_size();
        ranges.insert(*worker_id, (start_rank, dp_size));
    }
    ranges
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>, slots: Arc<ActiveSequencesMultiWorker<P>>,
...@@ -84,15 +98,7 @@ where ...@@ -84,15 +98,7 @@ where
continue; continue;
} }
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers let dp_range = Self::worker_dp_range(&current_workers);
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
slots_monitor.update_workers(&dp_range); slots_monitor.update_workers(&dp_range);
last_workers = current_workers; last_workers = current_workers;
} }
...@@ -168,7 +174,10 @@ where ...@@ -168,7 +174,10 @@ where
maybe_request_id: Option<String>, maybe_request_id: Option<String>,
isl_tokens: usize, isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores, tier_overlap_blocks: TierOverlapBlocks,
effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
tree_sizes: HashMap<WorkerWithDpRank, usize>,
router_config_override: Option<&super::config::RouterConfigOverride>, router_config_override: Option<&super::config::RouterConfigOverride>,
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
...@@ -186,7 +195,10 @@ where ...@@ -186,7 +195,10 @@ where
maybe_request_id, maybe_request_id,
token_seq, token_seq,
isl_tokens, isl_tokens,
overlaps, tier_overlap_blocks,
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes,
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens, track_prefill_tokens,
...@@ -258,7 +270,7 @@ where ...@@ -258,7 +270,7 @@ where
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
overlaps: OverlapScores, effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
let decay_now = Instant::now(); let decay_now = Instant::now();
...@@ -267,7 +279,7 @@ where ...@@ -267,7 +279,7 @@ where
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
token_seq.as_deref(), token_seq.as_deref(),
isl_tokens, isl_tokens,
overlaps, effective_cached_tokens,
track_prefill_tokens, track_prefill_tokens,
decay_now, decay_now,
); );
...@@ -306,7 +318,7 @@ mod tests { ...@@ -306,7 +318,7 @@ mod tests {
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use super::*; use super::*;
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores}; use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use crate::scheduling::PrefillLoadEstimator; use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy; use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector; use crate::scheduling::selector::DefaultWorkerSelector;
...@@ -423,7 +435,10 @@ mod tests { ...@@ -423,7 +435,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
Some("adapter-a".to_string()), Some("adapter-a".to_string()),
...@@ -462,7 +477,10 @@ mod tests { ...@@ -462,7 +477,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Some(&crate::config::RouterConfigOverride { Some(&crate::config::RouterConfigOverride {
track_prefill_tokens: Some(false), track_prefill_tokens: Some(false),
..Default::default() ..Default::default()
...@@ -489,6 +507,52 @@ mod tests { ...@@ -489,6 +507,52 @@ mod tests {
cancel_token.cancel(); cancel_token.cancel();
} }
/// A weighted cache hit passed to `schedule` must be echoed in the response and
/// must reduce the prefill load tracked for the chosen worker.
#[tokio::test]
async fn test_schedule_uses_weighted_cached_tokens_for_active_tracking() {
    let worker_config = SimpleWorkerConfig {
        max_num_batched_tokens: Some(64),
        ..Default::default()
    };
    let workers = HashMap::from([(0, worker_config)]);

    let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
    let worker = WorkerWithDpRank::new(0, 0);

    // Per-worker inputs: 0.75 effective overlap blocks and 48 cached tokens.
    let tier_overlap_blocks = TierOverlapBlocks::default();
    let effective_overlap_blocks = HashMap::from([(worker, 0.75)]);
    let effective_cached_tokens = HashMap::from([(worker, 48)]);
    let tree_sizes = HashMap::new();

    let scheduled = scheduler.schedule(
        Some("req-1".to_string()), // maybe_request_id
        64,                        // isl_tokens
        Some(vec![1, 2, 3, 4]),    // token_seq
        tier_overlap_blocks,
        effective_overlap_blocks,
        effective_cached_tokens,
        tree_sizes,
        None, // router_config_override
        true, // update_states
        None, // lora_name
        0.0,
        None,
        None,
        None,
        None,
    );
    let response = scheduled.await.unwrap();

    assert_eq!(response.best_worker, worker);
    assert_eq!(response.cached_tokens, 48);
    assert_eq!(response.effective_overlap_blocks, 0.75);

    // 64 input tokens minus 48 cached tokens leaves 16 tracked for the worker.
    let tracked_tokens = slots.active_tokens(Instant::now()).get(&worker).copied();
    assert_eq!(
        tracked_tokens,
        Some(16),
        "weighted cached tokens should reduce tracked prefill load",
    );

    cancel_token.cancel();
}
#[tokio::test] #[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() { async fn test_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new(); let mut workers = HashMap::new();
...@@ -507,7 +571,10 @@ mod tests { ...@@ -507,7 +571,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -528,7 +595,10 @@ mod tests { ...@@ -528,7 +595,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -570,7 +640,10 @@ mod tests { ...@@ -570,7 +640,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -591,7 +664,10 @@ mod tests { ...@@ -591,7 +664,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -647,7 +723,10 @@ mod tests { ...@@ -647,7 +723,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -668,7 +747,10 @@ mod tests { ...@@ -668,7 +747,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -723,7 +805,10 @@ mod tests { ...@@ -723,7 +805,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -744,7 +829,10 @@ mod tests { ...@@ -744,7 +829,10 @@ mod tests {
Some("req-2".to_string()), Some("req-2".to_string()),
64, 64,
Some(vec![5, 6, 7, 8]), Some(vec![5, 6, 7, 8]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -797,7 +885,10 @@ mod tests { ...@@ -797,7 +885,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
Some("adapter-a".to_string()), Some("adapter-a".to_string()),
...@@ -839,14 +930,10 @@ mod tests { ...@@ -839,14 +930,10 @@ mod tests {
); );
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let token_seq = vec![11, 22, 33, 44]; let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default(); let cached_tokens = HashMap::new();
let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens( let (decode_blocks, prefill_tokens) =
Some(&token_seq), slots.potential_blocks_and_tokens(Some(&token_seq), 128, cached_tokens.clone());
128,
overlaps.clone(),
Instant::now(),
);
let mut expected: Vec<_> = decode_blocks let mut expected: Vec<_> = decode_blocks
.keys() .keys()
.map(|worker| PotentialLoad { .map(|worker| PotentialLoad {
...@@ -858,7 +945,7 @@ mod tests { ...@@ -858,7 +945,7 @@ mod tests {
.collect(); .collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank)); expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps, true); let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, cached_tokens, true);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank)); actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len()); assert_eq!(actual.len(), expected.len());
...@@ -899,7 +986,10 @@ mod tests { ...@@ -899,7 +986,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
100, 100,
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -914,7 +1004,7 @@ mod tests { ...@@ -914,7 +1004,7 @@ mod tests {
tokio::time::advance(Duration::from_secs(6)).await; tokio::time::advance(Duration::from_secs(6)).await;
let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true); let loads = scheduler.get_potential_loads(None, 0, HashMap::new(), true);
assert_eq!(loads.len(), 1); assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 40); assert_eq!(loads[0].potential_prefill_tokens, 40);
...@@ -927,7 +1017,7 @@ mod tests { ...@@ -927,7 +1017,7 @@ mod tests {
make_scheduler(HashMap::new(), None, false, None); make_scheduler(HashMap::new(), None, false, None);
scheduler.register_workers(&HashSet::from([42])); scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true); let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), true);
assert_eq!(loads.len(), 1); assert_eq!(loads.len(), 1);
assert_eq!(loads[0].worker_id, 42); assert_eq!(loads[0].worker_id, 42);
...@@ -944,7 +1034,7 @@ mod tests { ...@@ -944,7 +1034,7 @@ mod tests {
assert_eq!( assert_eq!(
scheduler scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true) .get_potential_loads(None, 64, HashMap::new(), true,)
.len(), .len(),
1 1
); );
...@@ -963,7 +1053,7 @@ mod tests { ...@@ -963,7 +1053,7 @@ mod tests {
tokio::time::timeout(Duration::from_secs(1), async { tokio::time::timeout(Duration::from_secs(1), async {
loop { loop {
if scheduler if scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true) .get_potential_loads(None, 64, HashMap::new(), true)
.len() .len()
== 3 == 3
{ {
...@@ -995,7 +1085,10 @@ mod tests { ...@@ -995,7 +1085,10 @@ mod tests {
Some("req-1".to_string()), Some("req-1".to_string()),
64, 64,
Some(vec![11, 22]), Some(vec![11, 22]),
OverlapScores::default(), TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None, None,
true, true,
None, None,
...@@ -1008,7 +1101,7 @@ mod tests { ...@@ -1008,7 +1101,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), false); let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), false);
assert_eq!(loads.len(), 1); assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 64); assert_eq!(loads[0].potential_prefill_tokens, 64);
......
...@@ -66,16 +66,34 @@ impl SchedulingPolicy for LcfsPolicy { ...@@ -66,16 +66,34 @@ impl SchedulingPolicy for LcfsPolicy {
/// Optimizes for average TTFT — minimizes total weighted completion time /// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before /// (Smith 1956). Short or high-priority requests are scheduled before
/// long low-priority ones, reducing mean latency across the batch. /// long low-priority ones, reducing mean latency across the batch.
pub struct WsptPolicy { pub struct WsptPolicy;
pub block_size: usize,
}
impl SchedulingPolicy for WsptPolicy { impl SchedulingPolicy for WsptPolicy {
type Key = OrderedFloat<f64>; type Key = OrderedFloat<f64>;
fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key { fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
let weight = 1.0 + request.priority_jump.max(0.0); let weight = 1.0 + request.priority_jump.max(0.0);
let cached_tokens = request.overlap_blocks() as usize * self.block_size; let allowed_ids = request.allowed_worker_ids.as_ref();
let cached_tokens = request.pinned_worker.map_or_else(
|| {
request
.effective_cached_tokens
.iter()
.filter(|(worker, _)| {
allowed_ids.is_none_or(|ids| ids.contains(&worker.worker_id))
})
.map(|(_, tokens)| *tokens)
.max()
.unwrap_or(0)
},
|worker| {
request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0)
},
);
let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1); let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
OrderedFloat(weight / new_tokens as f64) OrderedFloat(weight / new_tokens as f64)
} }
...@@ -91,11 +109,11 @@ pub enum RouterSchedulingPolicy { ...@@ -91,11 +109,11 @@ pub enum RouterSchedulingPolicy {
} }
impl RouterSchedulingPolicy { impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self { pub fn new(kind: RouterQueuePolicy) -> Self {
match kind { match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy), RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy), RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }), RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy),
} }
} }
} }
...@@ -124,11 +142,24 @@ mod tests { ...@@ -124,11 +142,24 @@ mod tests {
priority_jump: f64, priority_jump: f64,
overlaps: OverlapScores, overlaps: OverlapScores,
) -> SchedulingRequest { ) -> SchedulingRequest {
let effective_overlap_blocks = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as f64))
.collect();
let effective_cached_tokens = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as usize * 16))
.collect();
SchedulingRequest { SchedulingRequest {
maybe_request_id: None, maybe_request_id: None,
token_seq: None, token_seq: None,
isl_tokens, isl_tokens,
overlaps, tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes: std::collections::HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
...@@ -224,10 +255,10 @@ mod tests { ...@@ -224,10 +255,10 @@ mod tests {
let early = Duration::from_secs(1); let early = Duration::from_secs(1);
let late = Duration::from_secs(10); let late = Duration::from_secs(10);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16); let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs);
assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req)); assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16); let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs);
assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req)); assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
} }
...@@ -235,7 +266,7 @@ mod tests { ...@@ -235,7 +266,7 @@ mod tests {
#[test] #[test]
fn wspt_shorter_request_scheduled_first() { fn wspt_shorter_request_scheduled_first() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let short = request_with(100, 0.0, OverlapScores::default()); let short = request_with(100, 0.0, OverlapScores::default());
let long = request_with(1000, 0.0, OverlapScores::default()); let long = request_with(1000, 0.0, OverlapScores::default());
let t = Duration::ZERO; let t = Duration::ZERO;
...@@ -247,7 +278,7 @@ mod tests { ...@@ -247,7 +278,7 @@ mod tests {
#[test] #[test]
fn wspt_overlap_reduces_effective_cost() { fn wspt_overlap_reduces_effective_cost() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
// Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens). // Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens).
let no_cache = request_with(1024, 0.0, OverlapScores::default()); let no_cache = request_with(1024, 0.0, OverlapScores::default());
let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)])); let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
...@@ -262,7 +293,7 @@ mod tests { ...@@ -262,7 +293,7 @@ mod tests {
#[test] #[test]
fn wspt_priority_promotes() { fn wspt_priority_promotes() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let normal = request_with(512, 0.0, OverlapScores::default()); let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 5.0, OverlapScores::default()); let boosted = request_with(512, 5.0, OverlapScores::default());
let t = Duration::ZERO; let t = Duration::ZERO;
...@@ -274,7 +305,7 @@ mod tests { ...@@ -274,7 +305,7 @@ mod tests {
#[test] #[test]
fn wspt_uses_max_overlap() { fn wspt_uses_max_overlap() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
// 4 workers with overlaps [10, 20, 50, 60]. max = 60. // 4 workers with overlaps [10, 20, 50, 60]. max = 60.
// new_tokens = 1024 - 60*16 = 64 // new_tokens = 1024 - 60*16 = 64
let req = request_with( let req = request_with(
...@@ -289,7 +320,7 @@ mod tests { ...@@ -289,7 +320,7 @@ mod tests {
#[test] #[test]
fn wspt_uses_pinned_worker_overlap_when_present() { fn wspt_uses_pinned_worker_overlap_when_present() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)])); let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0)); req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
...@@ -300,7 +331,7 @@ mod tests { ...@@ -300,7 +331,7 @@ mod tests {
#[test] #[test]
fn wspt_missing_pinned_overlap_uses_zero() { fn wspt_missing_pinned_overlap_uses_zero() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)])); let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0)); req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
...@@ -311,7 +342,7 @@ mod tests { ...@@ -311,7 +342,7 @@ mod tests {
#[test] #[test]
fn wspt_no_overlap_falls_back_to_isl() { fn wspt_no_overlap_falls_back_to_isl() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
let req = request_with(512, 0.0, OverlapScores::default()); let req = request_with(512, 0.0, OverlapScores::default());
let key = policy.enqueue_key(Duration::ZERO, &req); let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 512.0); let expected = OrderedFloat(1.0 / 512.0);
...@@ -320,7 +351,7 @@ mod tests { ...@@ -320,7 +351,7 @@ mod tests {
#[test] #[test]
fn wspt_full_overlap_clamps_to_one() { fn wspt_full_overlap_clamps_to_one() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy;
// 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1) // 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1)
let req = request_with(512, 0.0, overlaps_from(&[(0, 64)])); let req = request_with(512, 0.0, overlaps_from(&[(0, 64)]));
let key = policy.enqueue_key(Duration::ZERO, &req); let key = policy.enqueue_key(Duration::ZERO, &req);
......
...@@ -239,7 +239,7 @@ impl< ...@@ -239,7 +239,7 @@ impl<
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(), request.token_seq.as_deref(),
request.isl_tokens, request.isl_tokens,
request.overlaps.clone(), request.effective_cached_tokens.clone(),
request.track_prefill_tokens, request.track_prefill_tokens,
decay_now, decay_now,
); );
...@@ -263,7 +263,8 @@ impl< ...@@ -263,7 +263,8 @@ impl<
request.respond(Ok(SchedulingResponse { request.respond(Ok(SchedulingResponse {
best_worker: selection.worker, best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks, effective_overlap_blocks: selection.effective_overlap_blocks,
cached_tokens: selection.cached_tokens,
})); }));
if !request.update_states { if !request.update_states {
...@@ -277,7 +278,7 @@ impl< ...@@ -277,7 +278,7 @@ impl<
let prefill_load_hint = self.prefill_load_hint_for( let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens, request.isl_tokens,
selection.overlap_blocks, selection.cached_tokens,
request.track_prefill_tokens, request.track_prefill_tokens,
); );
...@@ -291,7 +292,7 @@ impl< ...@@ -291,7 +292,7 @@ impl<
worker: selection.worker, worker: selection.worker,
lora_name: request.lora_name.clone(), lora_name: request.lora_name.clone(),
}, },
decay_now, Instant::now(),
) { ) {
tracing::warn!("Failed to add request {request_id}: {e}"); tracing::warn!("Failed to add request {request_id}: {e}");
} }
...@@ -300,14 +301,14 @@ impl< ...@@ -300,14 +301,14 @@ impl<
fn prefill_load_hint_for( fn prefill_load_hint_for(
&self, &self,
isl_tokens: usize, isl_tokens: usize,
overlap_blocks: u32, cached_tokens: usize,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> { ) -> Option<PrefillLoadHint> {
if !track_prefill_tokens { if !track_prefill_tokens {
return None; return None;
} }
let prefix = (overlap_blocks as usize) * (self.block_size as usize); let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix); let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 { if effective_isl == 0 {
return None; return None;
...@@ -408,7 +409,7 @@ mod tests { ...@@ -408,7 +409,7 @@ mod tests {
use tokio::sync::{Barrier, watch}; use tokio::sync::{Barrier, watch};
use super::*; use super::*;
use crate::protocols::{OverlapScores, WorkerSelectionResult, WorkerWithDpRank}; use crate::protocols::{WorkerSelectionResult, WorkerWithDpRank};
use crate::scheduling::types::KvSchedulerError; use crate::scheduling::types::KvSchedulerError;
use crate::sequences::ActiveSequencesMultiWorker; use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
...@@ -499,7 +500,16 @@ mod tests { ...@@ -499,7 +500,16 @@ mod tests {
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker, worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64, required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: request.overlaps.scores.get(&worker).copied().unwrap_or(0), effective_overlap_blocks: request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0),
cached_tokens: request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0),
}) })
} }
} }
...@@ -628,7 +638,10 @@ mod tests { ...@@ -628,7 +638,10 @@ mod tests {
maybe_request_id: Some(request_id.to_string()), maybe_request_id: Some(request_id.to_string()),
token_seq: None, token_seq: None,
isl_tokens, isl_tokens,
overlaps: OverlapScores::default(), tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
...@@ -1020,7 +1033,10 @@ mod tests { ...@@ -1020,7 +1033,10 @@ mod tests {
maybe_request_id: Some("filter-0".to_string()), maybe_request_id: Some("filter-0".to_string()),
token_seq: None, token_seq: None,
isl_tokens: isl, isl_tokens: isl,
overlaps: OverlapScores::default(), tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
......
...@@ -37,59 +37,49 @@ fn softmax_sample_with_sample( ...@@ -37,59 +37,49 @@ fn softmax_sample_with_sample(
temperature: f64, temperature: f64,
sample: f64, sample: f64,
) -> (WorkerWithDpRank, f64) { ) -> (WorkerWithDpRank, f64) {
if logits.is_empty() { assert!(!logits.is_empty(), "Empty logits for softmax sampling");
panic!("Empty logits for softmax sampling");
}
// Guard: at zero temperature, return a minimum-logit worker directly.
if temperature == 0.0 { if temperature == 0.0 {
let mut logit_iter = logits.iter(); let (worker, logit) = logits
let (first_key, first_logit) = logit_iter.next().unwrap(); .iter()
.min_by(|a, b| a.1.total_cmp(b.1))
let mut min_logit = first_logit; .expect("logits non-empty");
let mut min_key = first_key; return (*worker, *logit);
for (key, logit) in logit_iter {
if logit < min_logit {
min_logit = logit;
min_key = key;
}
}
return (*min_key, *min_logit);
} }
let entries: Vec<_> = logits let entries: Vec<(WorkerWithDpRank, f64)> = logits.iter().map(|(w, l)| (*w, *l)).collect();
.iter()
.map(|(worker, logit)| (*worker, *logit))
.collect();
let values: Vec<_> = entries.iter().map(|(_, logit)| *logit).collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b)); let (min_val, max_val) = entries
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); .iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), (_, v)| {
(lo.min(*v), hi.max(*v))
});
let probabilities = if min_val == max_val { let mut probs = if min_val == max_val {
vec![1.0 / entries.len() as f64; entries.len()] vec![1.0 / entries.len() as f64; entries.len()]
} else { } else {
// Fused normalize -> negate -> scale -> exp, then normalize probabilities // Negate logits and rescale to [−1/temperature, 0] for numerical stability
let range = max_val - min_val; // before softmax. Subtracting the max (which maps to min_val) keeps exp() inputs ≤ 0.
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect(); let scale = -1.0 / ((max_val - min_val) * temperature);
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let max_scaled = min_val * scale;
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect(); entries
let sum: f64 = probs.iter().sum(); .iter()
probs.iter_mut().for_each(|p| *p /= sum); .map(|(_, v)| (v * scale - max_scaled).exp())
probs .collect::<Vec<f64>>()
}; };
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
let mut cumsum = 0.0; let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() { for (i, &prob) in probs.iter().enumerate() {
cumsum += prob; cumsum += prob;
if sample <= cumsum { if sample <= cumsum {
return entries[i]; return entries[i];
} }
} }
// Fallback to last key (shouldn't normally reach here) *entries.last().unwrap()
entries[entries.len() - 1]
} }
/// Default implementation matching the Python _cost_function. /// Default implementation matching the Python _cost_function.
...@@ -99,12 +89,6 @@ pub struct DefaultWorkerSelector { ...@@ -99,12 +89,6 @@ pub struct DefaultWorkerSelector {
pub worker_type: &'static str, pub worker_type: &'static str,
} }
#[derive(Debug, Clone, Copy)]
struct WorkerScore {
overlap_blocks: u32,
logit: f64,
}
impl DefaultWorkerSelector { impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self { pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self {
Self { Self {
...@@ -113,7 +97,7 @@ impl DefaultWorkerSelector { ...@@ -113,7 +97,7 @@ impl DefaultWorkerSelector {
} }
} }
fn worker_score( fn worker_logit(
&self, &self,
request: &SchedulingRequest, request: &SchedulingRequest,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
...@@ -121,9 +105,16 @@ impl DefaultWorkerSelector { ...@@ -121,9 +105,16 @@ impl DefaultWorkerSelector {
overlap_weight: f64, overlap_weight: f64,
shared_cache_multiplier: f64, shared_cache_multiplier: f64,
formula_name: &'static str, formula_name: &'static str,
) -> WorkerScore { ) -> f64 {
let isl = request.isl_tokens; let isl = request.isl_tokens;
let overlap_blocks = request.overlaps.scores.get(&worker).copied().unwrap_or(0); let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
// `shared_cache_hits::hits_beyond` expects an integer block count, so
// round the weighted overlap for this comparison only.
let device_overlap_blocks = effective_overlap_blocks.round().max(0.0) as u32;
let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 }; let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 };
let prefill_token = request let prefill_token = request
.prefill_tokens .prefill_tokens
...@@ -134,7 +125,7 @@ impl DefaultWorkerSelector { ...@@ -134,7 +125,7 @@ impl DefaultWorkerSelector {
// Adjust prefill tokens by shared cache hits beyond this worker's device prefix. // Adjust prefill tokens by shared cache hits beyond this worker's device prefix.
let (adjusted_prefill_token, shared_beyond) = let (adjusted_prefill_token, shared_beyond) =
if let Some(ref shared_hits) = request.shared_cache_hits { if let Some(ref shared_hits) = request.shared_cache_hits {
let beyond = shared_hits.hits_beyond(overlap_blocks); let beyond = shared_hits.hits_beyond(device_overlap_blocks);
let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64); let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64);
let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize; let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize;
(adjusted, beyond) (adjusted, beyond)
...@@ -153,7 +144,7 @@ impl DefaultWorkerSelector { ...@@ -153,7 +144,7 @@ impl DefaultWorkerSelector {
if shared_beyond > 0 { if shared_beyond > 0 {
tracing::debug!( tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} device blocks, \ "{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective device blocks, \
{shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \ {shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \
= {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \ = {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \ = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \
...@@ -163,7 +154,7 @@ impl DefaultWorkerSelector { ...@@ -163,7 +154,7 @@ impl DefaultWorkerSelector {
); );
} else { } else {
tracing::debug!( tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \ "{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}", = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id, worker.worker_id,
...@@ -171,10 +162,7 @@ impl DefaultWorkerSelector { ...@@ -171,10 +162,7 @@ impl DefaultWorkerSelector {
); );
} }
WorkerScore { logit
overlap_blocks,
logit,
}
} }
} }
...@@ -201,7 +189,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -201,7 +189,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let isl = request.isl_tokens; let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize); let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
let overlap_weight = request let overlap_weight = request
.router_config_override .router_config_override
...@@ -218,7 +205,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -218,7 +205,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
if let Some(worker) = pinned_worker { if let Some(worker) = pinned_worker {
pinned_worker_config(workers, worker)?; pinned_worker_config(workers, worker)?;
let score = self.worker_score( let logit = self.worker_logit(
request, request,
worker, worker,
block_size, block_size,
...@@ -226,11 +213,31 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -226,11 +213,31 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier, shared_cache_multiplier,
"Pinned formula", "Pinned formula",
); );
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0);
tracing::info!(
"Selected pinned worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}",
self.worker_type,
worker.worker_id,
worker.dp_rank,
logit,
effective_overlap_blocks,
);
return Ok(WorkerSelectionResult { return Ok(WorkerSelectionResult {
worker, worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: score.overlap_blocks, effective_overlap_blocks,
cached_tokens,
}); });
} }
...@@ -241,7 +248,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -241,7 +248,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let get_score = |worker: WorkerWithDpRank| -> f64 { let get_score = |worker: WorkerWithDpRank| -> f64 {
self.worker_score( self.worker_logit(
request, request,
worker, worker,
block_size, block_size,
...@@ -249,7 +256,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -249,7 +256,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier, shared_cache_multiplier,
"Formula", "Formula",
) )
.logit
}; };
let worker_iter = workers let worker_iter = workers
...@@ -282,7 +288,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -282,7 +288,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
); );
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers
.iter() .iter()
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w)) .map(|w| (request.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect(); .collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) { if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
...@@ -305,22 +311,58 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -305,22 +311,58 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
softmax_sample(&worker_logits, temperature) softmax_sample(&worker_logits, temperature)
}; };
let best_host_pinned_overlap_blocks = request
.tier_overlap_blocks
.host_pinned
.get(&best_worker)
.copied()
.unwrap_or(0);
let best_disk_overlap_blocks = request
.tier_overlap_blocks
.disk
.get(&best_worker)
.copied()
.unwrap_or(0);
if self.worker_type == "decode" { if self.worker_type == "decode" {
tracing::info!( tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}", "Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, host_pinned blocks: {}, disk blocks: {}",
self.worker_type, self.worker_type,
best_worker.worker_id, best_worker.worker_id,
best_worker.dp_rank, best_worker.dp_rank,
best_logit, best_logit,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
); );
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
return Ok(WorkerSelectionResult { return Ok(WorkerSelectionResult {
worker: best_worker, worker: best_worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0), effective_overlap_blocks,
cached_tokens,
}); });
} }
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0); let best_overlap = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let best_cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
let total_blocks_info = workers let total_blocks_info = workers
.get(&best_worker.worker_id) .get(&best_worker.worker_id)
...@@ -328,20 +370,17 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -328,20 +370,17 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.map(|blocks| format!(", total blocks: {}", blocks)) .map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default(); .unwrap_or_default();
let tree_size = request let tree_size = request.tree_sizes.get(&best_worker).copied().unwrap_or(0);
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
tracing::info!( tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}", "Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}, host_pinned blocks: {}, disk blocks: {}, tree size: {}{}",
self.worker_type, self.worker_type,
best_worker.worker_id, best_worker.worker_id,
best_worker.dp_rank, best_worker.dp_rank,
best_logit, best_logit,
best_overlap, best_overlap,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
tree_size, tree_size,
total_blocks_info total_blocks_info
); );
...@@ -349,7 +388,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -349,7 +388,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
Ok(WorkerSelectionResult { Ok(WorkerSelectionResult {
worker: best_worker, worker: best_worker,
required_blocks: request_blocks as u64, required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0), effective_overlap_blocks: best_overlap,
cached_tokens: best_cached_tokens,
}) })
} }
} }
...@@ -479,7 +519,6 @@ mod tests { ...@@ -479,7 +519,6 @@ mod tests {
/// Worker 1 has lower logit (less work), so it wins. /// Worker 1 has lower logit (less work), so it wins.
#[test] #[test]
fn test_shared_cache_hits_scoring() { fn test_shared_cache_hits_scoring() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig; use crate::test_utils::SimpleWorkerConfig;
let block_size = 1u32; let block_size = 1u32;
...@@ -487,8 +526,8 @@ mod tests { ...@@ -487,8 +526,8 @@ mod tests {
let worker0 = WorkerWithDpRank::from_worker_id(0); let worker0 = WorkerWithDpRank::from_worker_id(0);
let worker1 = WorkerWithDpRank::from_worker_id(1); let worker1 = WorkerWithDpRank::from_worker_id(1);
let mut overlaps = OverlapScores::new(); let mut effective_overlap_blocks = HashMap::new();
overlaps.scores.insert(worker0, 2); effective_overlap_blocks.insert(worker0, 2.0);
// worker1 has 0 overlap (not in map) // worker1 has 0 overlap (not in map)
#[allow(clippy::single_range_in_vec_init)] #[allow(clippy::single_range_in_vec_init)]
...@@ -511,7 +550,10 @@ mod tests { ...@@ -511,7 +550,10 @@ mod tests {
maybe_request_id: Some("test".into()), maybe_request_id: Some("test".into()),
token_seq: None, token_seq: None,
isl_tokens: isl, isl_tokens: isl,
overlaps, tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
...@@ -540,15 +582,14 @@ mod tests { ...@@ -540,15 +582,14 @@ mod tests {
/// Without shared cache hits, the scoring should be unchanged. /// Without shared cache hits, the scoring should be unchanged.
#[test] #[test]
fn test_no_shared_cache_unchanged() { fn test_no_shared_cache_unchanged() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig; use crate::test_utils::SimpleWorkerConfig;
let block_size = 16u32; let block_size = 16u32;
let isl = 64usize; let isl = 64usize;
let worker0 = WorkerWithDpRank::from_worker_id(0); let worker0 = WorkerWithDpRank::from_worker_id(0);
let mut overlaps = OverlapScores::new(); let mut effective_overlap_blocks = HashMap::new();
overlaps.scores.insert(worker0, 2); effective_overlap_blocks.insert(worker0, 2.0);
let config = KvRouterConfig::default(); let config = KvRouterConfig::default();
let selector = DefaultWorkerSelector::new(Some(config), "test"); let selector = DefaultWorkerSelector::new(Some(config), "test");
...@@ -560,7 +601,10 @@ mod tests { ...@@ -560,7 +601,10 @@ mod tests {
maybe_request_id: Some("test".into()), maybe_request_id: Some("test".into()),
token_seq: None, token_seq: None,
isl_tokens: isl, isl_tokens: isl,
overlaps, tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(), decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(), prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true, track_prefill_tokens: true,
......
...@@ -8,9 +8,13 @@ use rustc_hash::FxHashMap; ...@@ -8,9 +8,13 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::config::RouterConfigOverride; use super::config::RouterConfigOverride;
use crate::protocols::{ use crate::protocols::{DpRank, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank};
DpRank, OverlapScores, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank,
}; #[derive(Debug, Clone, Default, Serialize, Deserialize)]
/// Per-worker overlap block counts, split by storage tier and keyed by
/// worker + DP rank.
///
/// The worker selector looks up the chosen worker in each map when logging
/// its decision and treats a missing entry as zero blocks.
pub struct TierOverlapBlocks {
/// Overlap block count per worker for the host-pinned tier.
/// NOTE(review): presumably blocks matched in pinned host memory — confirm
/// against the code that populates this map.
pub host_pinned: HashMap<WorkerWithDpRank, usize>,
/// Overlap block count per worker for the disk tier.
/// NOTE(review): presumably blocks matched in disk-backed cache — confirm
/// against the code that populates this map.
pub disk: HashMap<WorkerWithDpRank, usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad { pub struct PotentialLoad {
...@@ -38,14 +42,20 @@ pub enum KvSchedulerError { ...@@ -38,14 +42,20 @@ pub enum KvSchedulerError {
#[derive(Debug)] #[derive(Debug)]
pub struct SchedulingResponse { pub struct SchedulingResponse {
pub best_worker: WorkerWithDpRank, pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32, pub effective_overlap_blocks: f64,
pub cached_tokens: usize,
} }
pub struct SchedulingRequest { pub struct SchedulingRequest {
pub maybe_request_id: Option<String>, pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>, pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlaps: OverlapScores, pub tier_overlap_blocks: TierOverlapBlocks,
pub effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
pub effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
/// Per-worker tree size, used only for tie-breaking when multiple workers
/// produce the same logit at temperature=0.
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
pub decode_blocks: FxHashMap<WorkerWithDpRank, usize>, pub decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: FxHashMap<WorkerWithDpRank, usize>, pub prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
pub track_prefill_tokens: bool, pub track_prefill_tokens: bool,
...@@ -85,16 +95,6 @@ impl SchedulingRequest { ...@@ -85,16 +95,6 @@ impl SchedulingRequest {
}) })
} }
/// Scheduling consumers use the exact pinned-worker overlap when present;
/// otherwise they use the best available overlap across eligible workers.
pub fn overlap_blocks(&self) -> u32 {
if let Some(worker) = self.pinned_worker {
return self.overlaps.scores.get(&worker).copied().unwrap_or(0);
}
self.overlaps.scores.values().copied().max().unwrap_or(0)
}
pub fn bypass_capacity_check(&self) -> bool { pub fn bypass_capacity_check(&self) -> bool {
self.pinned_worker.is_none() && self.allowed_worker_ids.is_some() self.pinned_worker.is_none() && self.allowed_worker_ids.is_some()
} }
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::FxHashMap; use rustc_hash::{FxBuildHasher, FxHashMap};
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
...@@ -26,8 +26,7 @@ use super::request_maps::RequestIndex; ...@@ -26,8 +26,7 @@ use super::request_maps::RequestIndex;
use super::single::{ActiveSequences, PromptMembershipDelta, RequestId}; use super::single::{ActiveSequences, PromptMembershipDelta, RequestId};
use super::topology::WorkerTable; use super::topology::WorkerTable;
use crate::protocols::{ use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, PrefillLoadHint, ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, PrefillLoadHint, WorkerWithDpRank,
WorkerWithDpRank,
}; };
// How often we force expire stale requests across all workers. See the comment // How often we force expire stale requests across all workers. See the comment
...@@ -90,6 +89,9 @@ pub enum SequenceError { ...@@ -90,6 +89,9 @@ pub enum SequenceError {
#[error("Request {request_id} not found")] #[error("Request {request_id} not found")]
RequestNotFound { request_id: String }, RequestNotFound { request_id: String },
#[error("Failed to publish replica-sync event: {0}")]
ReplicaSyncPublishFailed(String),
} }
/// Bundled parameters for adding a request to the sequence tracker. /// Bundled parameters for adding a request to the sequence tracker.
...@@ -587,8 +589,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -587,8 +589,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
&self, &self,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlaps: OverlapScores, cached_tokens: HashMap<WorkerWithDpRank, usize>,
decay_now: Instant,
) -> ( ) -> (
FxHashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
...@@ -596,9 +597,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -596,9 +597,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.potential_blocks_and_tokens_with_prefill_tracking( self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence, token_sequence,
isl, isl,
overlaps, cached_tokens,
true, true,
decay_now, Instant::now(),
) )
} }
...@@ -606,22 +607,54 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -606,22 +607,54 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
&self, &self,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlaps: OverlapScores, cached_tokens: HashMap<WorkerWithDpRank, usize>,
track_prefill_tokens: bool, track_prefill_tokens: bool,
decay_now: Instant, decay_now: Instant,
) -> ( ) -> (
FxHashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
) { ) {
self.prompt_registry #[cfg(feature = "bench")]
.potential_blocks_and_tokens_with_prefill_tracking( let start = tokio::time::Instant::now();
token_sequence,
isl, let table = self.workers.read();
&overlaps,
track_prefill_tokens, #[cfg(feature = "bench")]
self.block_size, let num_workers = table.slots.len();
decay_now,
) let mut potential_blocks =
FxHashMap::with_capacity_and_hasher(table.slots.len(), FxBuildHasher);
let mut potential_tokens =
FxHashMap::with_capacity_and_hasher(table.slots.len(), FxBuildHasher);
for slot in &table.slots {
let worker_cached_tokens = cached_tokens.get(&slot.worker).copied().unwrap_or(0);
let (blocks, tokens) = slot
.sequences
.read()
.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
worker_cached_tokens,
track_prefill_tokens,
decay_now,
);
potential_blocks.insert(slot.worker, blocks);
potential_tokens.insert(slot.worker, tokens);
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens)
} }
/// Query all workers for their current number of active blocks. /// Query all workers for their current number of active blocks.
...@@ -945,7 +978,6 @@ mod tests { ...@@ -945,7 +978,6 @@ mod tests {
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use super::super::prefill_tracker::added_prefill_tokens;
use super::*; use super::*;
use crate::protocols::{ use crate::protocols::{
ActiveSequenceEvent, ActiveSequenceEventData, BlockHashOptions, OverlapScores, ActiveSequenceEvent, ActiveSequenceEventData, BlockHashOptions, OverlapScores,
...@@ -999,6 +1031,7 @@ mod tests { ...@@ -999,6 +1031,7 @@ mod tests {
FxHashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
) { ) {
let cached_tokens = cached_tokens_from_overlap_scores(overlaps, sequences.block_size);
let table = sequences.workers.read(); let table = sequences.workers.read();
let mut potential_blocks = FxHashMap::default(); let mut potential_blocks = FxHashMap::default();
let mut potential_tokens = FxHashMap::default(); let mut potential_tokens = FxHashMap::default();
...@@ -1013,9 +1046,9 @@ mod tests { ...@@ -1013,9 +1046,9 @@ mod tests {
}); });
let new_blocks = let new_blocks =
token_sequence.map_or(0, |query| query.len().saturating_sub(overlap_depth)); token_sequence.map_or(0, |query| query.len().saturating_sub(overlap_depth));
let overlap = *overlaps.scores.get(&slot.worker).unwrap_or(&0); let worker_cached_tokens = *cached_tokens.get(&slot.worker).unwrap_or(&0);
let added_tokens = if track_prefill_tokens { let added_tokens = if track_prefill_tokens {
added_prefill_tokens(sequences.block_size, isl, overlap) seq.new_tokens(isl, worker_cached_tokens)
} else { } else {
0 0
}; };
...@@ -1025,6 +1058,17 @@ mod tests { ...@@ -1025,6 +1058,17 @@ mod tests {
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
/// Test helper: turns per-worker overlap scores (counted in blocks) into
/// per-worker cached-token counts by multiplying each score by `block_size`.
///
/// Workers that have no entry in `overlaps.scores` have no entry in the
/// returned map either.
fn cached_tokens_from_overlap_scores(
    overlaps: &OverlapScores,
    block_size: usize,
) -> HashMap<WorkerWithDpRank, usize> {
    let mut tokens_by_worker = HashMap::new();
    for (worker, overlap_blocks) in overlaps.scores.iter() {
        // blocks -> tokens
        let cached = (*overlap_blocks as usize) * block_size;
        tokens_by_worker.insert(*worker, cached);
    }
    tokens_by_worker
}
fn seq_hashes_for_tokens(tokens: &[u32], lora_name: Option<&str>) -> Vec<SequenceHash> { fn seq_hashes_for_tokens(tokens: &[u32], lora_name: Option<&str>) -> Vec<SequenceHash> {
seq_hashes_for_tokens_with_block_size(tokens, 4, lora_name) seq_hashes_for_tokens_with_block_size(tokens, 4, lora_name)
} }
...@@ -1153,7 +1197,7 @@ mod tests { ...@@ -1153,7 +1197,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking( let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt), Some(&prompt),
16, 16,
actual_overlaps, cached_tokens_from_overlap_scores(&actual_overlaps, sequences.block_size),
true, true,
decay_now, decay_now,
); );
...@@ -1214,7 +1258,7 @@ mod tests { ...@@ -1214,7 +1258,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking( let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&base_prompt), Some(&base_prompt),
8, 8,
OverlapScores::default(), HashMap::new(),
false, false,
decay_now, decay_now,
); );
...@@ -1278,7 +1322,7 @@ mod tests { ...@@ -1278,7 +1322,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking( let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt_b), Some(&prompt_b),
3, 3,
OverlapScores::default(), cached_tokens_from_overlap_scores(&OverlapScores::default(), sequences.block_size),
false, false,
decay_now, decay_now,
); );
...@@ -1299,7 +1343,7 @@ mod tests { ...@@ -1299,7 +1343,7 @@ mod tests {
let actual_after_free = sequences.potential_blocks_and_tokens_with_prefill_tracking( let actual_after_free = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt_b), Some(&prompt_b),
3, 3,
OverlapScores::default(), cached_tokens_from_overlap_scores(&OverlapScores::default(), sequences.block_size),
false, false,
decay_now, decay_now,
); );
...@@ -1390,7 +1434,7 @@ mod tests { ...@@ -1390,7 +1434,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking( let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&[1, 2, 3]), Some(&[1, 2, 3]),
12, 12,
OverlapScores::default(), HashMap::new(),
false, false,
Instant::now(), Instant::now(),
); );
...@@ -1595,7 +1639,7 @@ mod tests { ...@@ -1595,7 +1639,7 @@ mod tests {
let (_, potential_tokens) = sequences.potential_blocks_and_tokens_with_prefill_tracking( let (_, potential_tokens) = sequences.potential_blocks_and_tokens_with_prefill_tracking(
None, None,
0, 0,
OverlapScores::default(), HashMap::new(),
false, false,
decay_now, decay_now,
); );
......
...@@ -51,6 +51,7 @@ impl PrefillLoadSnapshot { ...@@ -51,6 +51,7 @@ impl PrefillLoadSnapshot {
} }
} }
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize { pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * block_size; let cached_tokens = (overlap as usize) * block_size;
isl.checked_sub(cached_tokens).unwrap_or_else(|| { isl.checked_sub(cached_tokens).unwrap_or_else(|| {
......
...@@ -534,6 +534,7 @@ impl PromptMembershipTrie { ...@@ -534,6 +534,7 @@ impl PromptMembershipTrie {
} }
} }
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn compute_overlap_depths( pub(super) fn compute_overlap_depths(
&self, &self,
query: Option<&[SequenceHash]>, query: Option<&[SequenceHash]>,
......
...@@ -97,6 +97,7 @@ impl PromptRegistry { ...@@ -97,6 +97,7 @@ impl PromptRegistry {
} }
#[expect(clippy::too_many_arguments)] #[expect(clippy::too_many_arguments)]
#[cfg_attr(not(test), allow(dead_code))]
fn project_loads_from_overlap( fn project_loads_from_overlap(
&self, &self,
query_len: usize, query_len: usize,
...@@ -135,6 +136,7 @@ impl PromptRegistry { ...@@ -135,6 +136,7 @@ impl PromptRegistry {
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn potential_blocks_and_tokens_with_prefill_tracking( pub(super) fn potential_blocks_and_tokens_with_prefill_tracking(
&self, &self,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
......
...@@ -109,8 +109,6 @@ pub struct ActiveSequences { ...@@ -109,8 +109,6 @@ pub struct ActiveSequences {
requests: HashMap<RequestId, RequestState>, requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker, prefill: PrefillLoadTracker,
blocks: BlockTracker, blocks: BlockTracker,
#[cfg(test)]
block_size: usize,
last_expiry_check_time: Instant, last_expiry_check_time: Instant,
} }
...@@ -123,8 +121,6 @@ impl ActiveSequences { ...@@ -123,8 +121,6 @@ impl ActiveSequences {
requests: HashMap::new(), requests: HashMap::new(),
prefill: PrefillLoadTracker::default(), prefill: PrefillLoadTracker::default(),
blocks: BlockTracker::default(), blocks: BlockTracker::default(),
#[cfg(test)]
block_size,
last_expiry_check_time: Instant::now(), last_expiry_check_time: Instant::now(),
} }
} }
...@@ -157,14 +153,12 @@ impl ActiveSequences { ...@@ -157,14 +153,12 @@ impl ActiveSequences {
self.blocks.active_blocks() self.blocks.active_blocks()
} }
#[cfg(test)]
pub(super) fn active_tokens(&self, decay_now: Instant) -> usize { pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
self.prefill.snapshot().active_tokens_at(decay_now) self.prefill.snapshot().active_tokens_at(decay_now)
} }
/// Add a new request with optional prompt-token load accounting. /// Add a new request with optional prompt-token load accounting.
/// Returns block membership transitions plus any expired request IDs removed during cleanup. /// Returns block membership transitions plus any expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)]
pub(super) fn add_request_with_prefill_tracking( pub(super) fn add_request_with_prefill_tracking(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
...@@ -299,6 +293,31 @@ impl ActiveSequences { ...@@ -299,6 +293,31 @@ impl ActiveSequences {
membership_delta membership_delta
} }
/// Returns how many of the `isl` tokens are not covered by `cached_tokens`,
/// i.e. `isl - cached_tokens`.
///
/// If `cached_tokens` exceeds `isl` the subtraction would underflow; in that
/// case an error is logged and `0` is returned instead.
pub fn new_tokens(&self, isl: usize, cached_tokens: usize) -> usize {
    match isl.checked_sub(cached_tokens) {
        Some(uncached) => uncached,
        None => {
            tracing::error!(
                "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens}, returning 0"
            );
            0
        }
    }
}
/// Convenience wrapper around
/// [`Self::potential_blocks_and_tokens_with_prefill_tracking`] that always
/// enables prefill-token tracking.
///
/// Every other argument is forwarded unchanged and the callee's
/// `(usize, usize)` result is returned as-is.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    cached_tokens: usize,
    decay_now: Instant,
) -> (usize, usize) {
    // This entry point always has the request's new tokens counted.
    let track_prefill_tokens = true;
    self.potential_blocks_and_tokens_with_prefill_tracking(
        token_sequence,
        isl,
        cached_tokens,
        track_prefill_tokens,
        decay_now,
    )
}
/// Add an output block with a random hash and optional fractional decay weight. /// Add an output block with a random hash and optional fractional decay weight.
/// ///
/// This is used during generation to track output blocks as they are created. /// This is used during generation to track output blocks as they are created.
...@@ -330,12 +349,11 @@ impl ActiveSequences { ...@@ -330,12 +349,11 @@ impl ActiveSequences {
acquire.became_present_on_worker.then_some(random_hash) acquire.became_present_on_worker.then_some(random_hash)
} }
#[cfg(test)] pub fn potential_blocks_and_tokens_with_prefill_tracking(
fn potential_blocks_and_tokens_with_prefill_tracking(
&self, &self,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlap: u32, cached_tokens: usize,
track_prefill_tokens: bool, track_prefill_tokens: bool,
decay_now: Instant, decay_now: Instant,
) -> (usize, usize) { ) -> (usize, usize) {
...@@ -346,7 +364,7 @@ impl ActiveSequences { ...@@ -346,7 +364,7 @@ impl ActiveSequences {
}; };
let active_tokens = self.active_tokens(decay_now); let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens { let potential_tokens = if track_prefill_tokens {
added_prefill_tokens(self.block_size, isl, overlap) + active_tokens self.new_tokens(isl, cached_tokens) + active_tokens
} else { } else {
active_tokens active_tokens
}; };
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
//! //!
//! - This module is responsible for maintaining a registry of all blocks currently within a pool. //! - This module is responsible for maintaining a registry of all blocks currently within a pool.
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks. //! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
//! - The global registry is a mapping of sequences hashes to registration handles. If two blocks in different pools //! - The global registry is keyed by sequence hash and storage tier. If two blocks in different pools
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools. //! have the same sequence hash but live in different tiers, they keep distinct registration handles
//! so KVBM can emit per-tier events. The global registry is shared across all pools.
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are //! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime. //! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle. //! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
...@@ -27,12 +28,28 @@ use std::{ ...@@ -27,12 +28,28 @@ use std::{
use super::super::events::{EventManager, EventReleaseManager, PublishHandle}; use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
use super::state::BlockState; use super::state::BlockState;
use crate::block_manager::kv_consolidator::StorageTier;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock}; use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
use derive_getters::Getters; use derive_getters::Getters;
use tokio::{runtime::Handle, sync::mpsc}; use tokio::{runtime::Handle, sync::mpsc};
pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
/// Key of the global registry: a sequence hash paired with a storage tier,
/// so the same sequence hash registered from different tiers occupies
/// distinct entries.
pub struct RegistrationKey {
/// Sequence hash of the registered block.
sequence_hash: SequenceHash,
/// Storage tier of the `BlockRegistry` that registered the block.
storage_tier: StorageTier,
}
impl RegistrationKey {
fn new(sequence_hash: SequenceHash, storage_tier: StorageTier) -> Self {
Self {
sequence_hash,
storage_tier,
}
}
}
/// Mutex-guarded map, shareable through `Arc`, from each [`RegistrationKey`]
/// (sequence hash + storage tier) to a weak reference to its
/// [`RegistrationHandle`].
pub type GlobalRegistry = Arc<Mutex<HashMap<RegistrationKey, Weak<RegistrationHandle>>>>;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum BlockRegistrationError { pub enum BlockRegistrationError {
...@@ -72,6 +89,7 @@ impl Drop for BlockHandle { ...@@ -72,6 +89,7 @@ impl Drop for BlockHandle {
pub struct BlockRegistry { pub struct BlockRegistry {
blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>, blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>,
storage_tier: StorageTier,
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
unregister_tx: mpsc::UnboundedSender<SequenceHash>, unregister_tx: mpsc::UnboundedSender<SequenceHash>,
...@@ -82,6 +100,7 @@ impl BlockRegistry { ...@@ -82,6 +100,7 @@ impl BlockRegistry {
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
async_runtime: Handle, async_runtime: Handle,
storage_tier: StorageTier,
) -> Self { ) -> Self {
let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel(); let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel();
...@@ -105,17 +124,19 @@ impl BlockRegistry { ...@@ -105,17 +124,19 @@ impl BlockRegistry {
} }
let mut global_registry = global_registry.lock().unwrap(); let mut global_registry = global_registry.lock().unwrap();
let registration_key = RegistrationKey::new(sequence_hash, storage_tier);
if let Some(entry) = global_registry.get(&sequence_hash) if let Some(entry) = global_registry.get(&registration_key)
&& entry.upgrade().is_none() && entry.upgrade().is_none()
{ {
global_registry.remove(&sequence_hash); global_registry.remove(&registration_key);
} }
} }
}); });
Self { Self {
blocks, blocks,
storage_tier,
event_manager, event_manager,
global_registry, global_registry,
unregister_tx, unregister_tx,
...@@ -165,9 +186,10 @@ impl BlockRegistry { ...@@ -165,9 +186,10 @@ impl BlockRegistry {
let reg_handle = 'reg_block: { let reg_handle = 'reg_block: {
// Now, check the global registry. // Now, check the global registry.
let mut global_registry = self.global_registry.lock().unwrap(); let mut global_registry = self.global_registry.lock().unwrap();
let registration_key = RegistrationKey::new(sequence_hash, self.storage_tier);
// If an identical block exists in other pool, use the same registration handle. // If an identical block exists in other pool, use the same registration handle.
if let Some(handle) = global_registry.get(&sequence_hash) if let Some(handle) = global_registry.get(&registration_key)
&& let Some(handle) = handle.upgrade() && let Some(handle) = handle.upgrade()
{ {
break 'reg_block handle; break 'reg_block handle;
...@@ -177,11 +199,12 @@ impl BlockRegistry { ...@@ -177,11 +199,12 @@ impl BlockRegistry {
publish_handle = Some(Self::create_publish_handle( publish_handle = Some(Self::create_publish_handle(
state.token_block(), state.token_block(),
self.event_manager.clone(), self.event_manager.clone(),
self.storage_tier,
)); ));
let reg_handle = publish_handle.as_ref().unwrap().remove_handle(); let reg_handle = publish_handle.as_ref().unwrap().remove_handle();
// Insert the registration handle into the global registry. // Insert the registration handle into the global registry.
global_registry.insert(sequence_hash, Arc::downgrade(&reg_handle)); global_registry.insert(registration_key, Arc::downgrade(&reg_handle));
reg_handle reg_handle
}; };
...@@ -205,8 +228,10 @@ impl BlockRegistry { ...@@ -205,8 +228,10 @@ impl BlockRegistry {
fn create_publish_handle( fn create_publish_handle(
token_block: &TokenBlock, token_block: &TokenBlock,
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
storage_tier: StorageTier,
) -> PublishHandle { ) -> PublishHandle {
let reg_handle = RegistrationHandle::from_token_block(token_block, event_manager.clone()); let reg_handle =
RegistrationHandle::from_token_block(token_block, event_manager.clone(), storage_tier);
PublishHandle::new(reg_handle, event_manager) PublishHandle::new(reg_handle, event_manager)
} }
...@@ -223,6 +248,15 @@ pub struct RegistrationHandle { ...@@ -223,6 +248,15 @@ pub struct RegistrationHandle {
#[getter(copy)] #[getter(copy)]
parent_sequence_hash: Option<SequenceHash>, parent_sequence_hash: Option<SequenceHash>,
#[getter(copy)]
external_sequence_hash: Option<SequenceHash>,
#[getter(copy)]
external_parent_sequence_hash: Option<SequenceHash>,
#[getter(copy)]
storage_tier: StorageTier,
#[getter(skip)] #[getter(skip)]
release_manager: Arc<dyn EventReleaseManager>, release_manager: Arc<dyn EventReleaseManager>,
...@@ -240,14 +274,29 @@ impl RegistrationHandle { ...@@ -240,14 +274,29 @@ impl RegistrationHandle {
self.token_block.tokens() self.token_block.tokens()
} }
/// Returns the router-facing sequence hash for this block.
///
/// This is the external sequence hash when one is set, and the block's own
/// sequence hash otherwise.
pub fn published_sequence_hash(&self) -> SequenceHash {
    if let Some(external) = self.external_sequence_hash {
        return external;
    }
    self.sequence_hash
}
/// Returns the router-facing parent sequence hash for this block.
///
/// This is the external parent sequence hash when one is set; otherwise it
/// falls back to the block's own parent sequence hash, which may be `None`.
pub fn published_parent_sequence_hash(&self) -> Option<SequenceHash> {
    match self.external_parent_sequence_hash {
        external @ Some(_) => external,
        None => self.parent_sequence_hash,
    }
}
fn from_token_block( fn from_token_block(
token_block: &TokenBlock, token_block: &TokenBlock,
release_manager: Arc<dyn EventReleaseManager>, release_manager: Arc<dyn EventReleaseManager>,
storage_tier: StorageTier,
) -> Self { ) -> Self {
Self { Self {
block_hash: token_block.block_hash(), block_hash: token_block.block_hash(),
sequence_hash: token_block.sequence_hash(), sequence_hash: token_block.sequence_hash(),
parent_sequence_hash: token_block.parent_sequence_hash(), parent_sequence_hash: token_block.parent_sequence_hash(),
external_sequence_hash: token_block.external_sequence_hash(),
external_parent_sequence_hash: token_block.external_parent_sequence_hash(),
storage_tier,
release_manager, release_manager,
token_block: token_block.clone(), token_block: token_block.clone(),
} }
...@@ -258,8 +307,13 @@ impl std::fmt::Debug for RegistrationHandle { ...@@ -258,8 +307,13 @@ impl std::fmt::Debug for RegistrationHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!( write!(
f, f,
"RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?} }}", "RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?}; external_sequence_hash: {:?}; external_parent_sequence_hash: {:?}; storage_tier: {:?} }}",
self.sequence_hash, self.block_hash, self.parent_sequence_hash self.sequence_hash,
self.block_hash,
self.parent_sequence_hash,
self.external_sequence_hash,
self.external_parent_sequence_hash,
self.storage_tier
) )
} }
} }
...@@ -274,7 +328,9 @@ impl Drop for RegistrationHandle { ...@@ -274,7 +328,9 @@ impl Drop for RegistrationHandle {
mod tests { mod tests {
use super::*; use super::*;
use crate::block_manager::events::NullEventManager;
use crate::block_manager::events::tests::{EventType, MockEventManager}; use crate::block_manager::events::tests::{EventType, MockEventManager};
use crate::block_manager::kv_consolidator::StorageTier;
use crate::tokens::{TokenBlockSequence, Tokens}; use crate::tokens::{TokenBlockSequence, Tokens};
fn create_sequence() -> TokenBlockSequence { fn create_sequence() -> TokenBlockSequence {
...@@ -303,8 +359,11 @@ mod tests { ...@@ -303,8 +359,11 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new(); let (event_manager, mut rx) = MockEventManager::new();
let publish_handle = let publish_handle = BlockRegistry::create_publish_handle(
BlockRegistry::create_publish_handle(&sequence.blocks()[0], event_manager.clone()); &sequence.blocks()[0],
event_manager.clone(),
StorageTier::Device,
);
// no event should have been triggered // no event should have been triggered
assert!(rx.try_recv().is_err()); assert!(rx.try_recv().is_err());
...@@ -317,7 +376,7 @@ mod tests { ...@@ -317,7 +376,7 @@ mod tests {
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
assert_eq!( assert_eq!(
events[0], events[0],
EventType::Register(sequence.blocks()[0].sequence_hash()) EventType::Register(sequence.blocks()[0].sequence_hash(), StorageTier::Device)
); );
// the second event should be a Remove event // the second event should be a Remove event
...@@ -325,7 +384,7 @@ mod tests { ...@@ -325,7 +384,7 @@ mod tests {
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
assert_eq!( assert_eq!(
events[0], events[0],
EventType::Remove(sequence.blocks()[0].sequence_hash()) EventType::Remove(sequence.blocks()[0].sequence_hash(), StorageTier::Device)
); );
// there should be no more events // there should be no more events
...@@ -340,8 +399,11 @@ mod tests { ...@@ -340,8 +399,11 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new(); let (event_manager, mut rx) = MockEventManager::new();
let publish_handle = let publish_handle = BlockRegistry::create_publish_handle(
BlockRegistry::create_publish_handle(block_to_test, event_manager.clone()); block_to_test,
event_manager.clone(),
StorageTier::Device,
);
// Remove the registration handle before dropping the publish handle // Remove the registration handle before dropping the publish handle
let reg_handle = publish_handle.remove_handle(); let reg_handle = publish_handle.remove_handle();
...@@ -359,7 +421,7 @@ mod tests { ...@@ -359,7 +421,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
register_events[0], register_events[0],
EventType::Register(expected_sequence_hash), EventType::Register(expected_sequence_hash, StorageTier::Device),
"Expected Register event" "Expected Register event"
); );
...@@ -370,7 +432,7 @@ mod tests { ...@@ -370,7 +432,7 @@ mod tests {
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
assert_eq!( assert_eq!(
events[0], events[0],
EventType::Remove(expected_sequence_hash), EventType::Remove(expected_sequence_hash, StorageTier::Device),
"Only Remove event should be triggered" "Only Remove event should be triggered"
); );
...@@ -389,8 +451,16 @@ mod tests { ...@@ -389,8 +451,16 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new(); let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher(); let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone()); let publish_handle1 = BlockRegistry::create_publish_handle(
let publish_handle2 = BlockRegistry::create_publish_handle(block2, event_manager.clone()); block1,
event_manager.clone(),
StorageTier::Device,
);
let publish_handle2 = BlockRegistry::create_publish_handle(
block2,
event_manager.clone(),
StorageTier::Device,
);
// Remove handles before adding to publisher // Remove handles before adding to publisher
let reg_handle1 = publish_handle1.remove_handle(); let reg_handle1 = publish_handle1.remove_handle();
...@@ -413,8 +483,8 @@ mod tests { ...@@ -413,8 +483,8 @@ mod tests {
"Should receive two Register events in one batch" "Should receive two Register events in one batch"
); );
// Order isn't guaranteed, so check for both // Order isn't guaranteed, so check for both
assert!(events.contains(&EventType::Register(hash1))); assert!(events.contains(&EventType::Register(hash1, StorageTier::Device)));
assert!(events.contains(&EventType::Register(hash2))); assert!(events.contains(&EventType::Register(hash2, StorageTier::Device)));
// no more events immediately after publish // no more events immediately after publish
assert!(rx.try_recv().is_err()); assert!(rx.try_recv().is_err());
...@@ -423,12 +493,12 @@ mod tests { ...@@ -423,12 +493,12 @@ mod tests {
drop(reg_handle1); drop(reg_handle1);
let events1 = rx.try_recv().unwrap(); let events1 = rx.try_recv().unwrap();
assert_eq!(events1.len(), 1); assert_eq!(events1.len(), 1);
assert_eq!(events1[0], EventType::Remove(hash1)); assert_eq!(events1[0], EventType::Remove(hash1, StorageTier::Device));
drop(reg_handle2); drop(reg_handle2);
let events2 = rx.try_recv().unwrap(); let events2 = rx.try_recv().unwrap();
assert_eq!(events2.len(), 1); assert_eq!(events2.len(), 1);
assert_eq!(events2[0], EventType::Remove(hash2)); assert_eq!(events2[0], EventType::Remove(hash2, StorageTier::Device));
// no more events // no more events
assert!(rx.try_recv().is_err()); assert!(rx.try_recv().is_err());
...@@ -453,7 +523,11 @@ mod tests { ...@@ -453,7 +523,11 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new(); let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher(); let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone()); let publish_handle1 = BlockRegistry::create_publish_handle(
block1,
event_manager.clone(),
StorageTier::Device,
);
publisher.take_handle(publish_handle1); publisher.take_handle(publish_handle1);
...@@ -461,7 +535,7 @@ mod tests { ...@@ -461,7 +535,7 @@ mod tests {
publisher.publish(); publisher.publish();
let events = rx.try_recv().unwrap(); let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
assert_eq!(events[0], EventType::Register(hash1)); assert_eq!(events[0], EventType::Register(hash1, StorageTier::Device));
// The RegistrationHandle Arc was taken by the publisher and dropped after the publish call // The RegistrationHandle Arc was taken by the publisher and dropped after the publish call
// So, the Remove event should follow immediately. // So, the Remove event should follow immediately.
...@@ -473,7 +547,7 @@ mod tests { ...@@ -473,7 +547,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
remove_events[0], remove_events[0],
EventType::Remove(hash1), EventType::Remove(hash1, StorageTier::Device),
"Expected Remove event" "Expected Remove event"
); );
...@@ -485,4 +559,89 @@ mod tests { ...@@ -485,4 +559,89 @@ mod tests {
drop(publisher); drop(publisher);
assert!(rx.try_recv().is_err()); assert!(rx.try_recv().is_err());
} }
/// Registers the same block (same sequence hash) through two registries that
/// share one global registry and one event manager but use different storage
/// tiers, and checks that each tier produces its own Register and Remove
/// event.
#[tokio::test(flavor = "current_thread")]
async fn test_same_sequence_in_different_tiers_emits_distinct_events() {
let sequence = create_sequence();
let block = sequence.blocks()[0].clone();
let sequence_hash = block.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
// Both registries below share this one global registry.
let global_registry = GlobalRegistry::default();
let mut host_registry = BlockRegistry::new(
event_manager.clone(),
global_registry.clone(),
Handle::current(),
StorageTier::HostPinned,
);
let mut disk_registry = BlockRegistry::new(
event_manager.clone(),
global_registry,
Handle::current(),
StorageTier::Disk,
);
// Register the block in the host-pinned tier first.
let mut host_state = BlockState::Reset;
host_state.apply_token_block(block.clone()).unwrap();
let host_publish = host_registry
.register_block(&mut host_state)
.unwrap()
.unwrap();
// After the publish handle is dropped, a Register event tagged with the
// host-pinned tier is expected on the channel.
drop(host_publish);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Register(sequence_hash, StorageTier::HostPinned)]
);
// Register the very same block in the disk tier. The second `unwrap()`
// requires a publish handle to be returned here as well.
// NOTE(review): presumably no publish handle would be returned if the
// host-tier registration handle were reused — confirm against
// `register_block`.
let mut disk_state = BlockState::Reset;
disk_state.apply_token_block(block).unwrap();
let disk_publish = disk_registry
.register_block(&mut disk_state)
.unwrap()
.unwrap();
drop(disk_publish);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Register(sequence_hash, StorageTier::Disk)]
);
// Dropping each block state is expected to emit a Remove event carrying
// that state's own tier.
drop(host_state);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Remove(sequence_hash, StorageTier::HostPinned)]
);
drop(disk_state);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Remove(sequence_hash, StorageTier::Disk)]
);
}
/// A handle built from a block whose external hashes were synced must expose
/// those hashes through the raw `external_*` getters and through the
/// `published_*` accessors.
#[test]
fn test_registration_handle_prefers_external_hashes_for_publication() {
    let mut sequence = create_sequence();
    sequence.sync_external_sequence_hashes(&[50_001, 50_002]);

    // Use the second block so that a parent hash is expected as well.
    let second_block = &sequence.blocks()[1];
    let handle = RegistrationHandle::from_token_block(
        second_block,
        NullEventManager::new(),
        StorageTier::HostPinned,
    );

    // Raw getters report the synced external values.
    assert_eq!(handle.external_sequence_hash(), Some(50_002));
    assert_eq!(handle.external_parent_sequence_hash(), Some(50_001));

    // Published accessors resolve to the same external values.
    assert_eq!(handle.published_sequence_hash(), 50_002);
    assert_eq!(handle.published_parent_sequence_hash(), Some(50_001));
}
} }
...@@ -230,6 +230,7 @@ impl KvBlockManagerConfigBuilder { ...@@ -230,6 +230,7 @@ impl KvBlockManagerConfigBuilder {
engine_endpoint: String, engine_endpoint: String,
output_endpoint: Option<String>, output_endpoint: Option<String>,
engine_source: crate::block_manager::kv_consolidator::EventSource, engine_source: crate::block_manager::kv_consolidator::EventSource,
mode: crate::block_manager::kv_consolidator::KvEventConsolidationMode,
) -> Self { ) -> Self {
let config = match engine_source { let config = match engine_source {
crate::block_manager::kv_consolidator::EventSource::Vllm => { crate::block_manager::kv_consolidator::EventSource::Vllm => {
...@@ -237,6 +238,7 @@ impl KvBlockManagerConfigBuilder { ...@@ -237,6 +238,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_vllm( crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_vllm(
engine_endpoint, engine_endpoint,
output_ep, output_ep,
mode,
) )
} }
crate::block_manager::kv_consolidator::EventSource::Trtllm => { crate::block_manager::kv_consolidator::EventSource::Trtllm => {
...@@ -248,6 +250,7 @@ impl KvBlockManagerConfigBuilder { ...@@ -248,6 +250,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_trtllm( crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_trtllm(
engine_endpoint, engine_endpoint,
output_ep, output_ep,
mode,
) )
} }
crate::block_manager::kv_consolidator::EventSource::Kvbm => { crate::block_manager::kv_consolidator::EventSource::Kvbm => {
......
...@@ -201,18 +201,21 @@ impl DynamoEventManager { ...@@ -201,18 +201,21 @@ impl DynamoEventManager {
rt.spawn(async move { rt.spawn(async move {
for handle in handles { for handle in handles {
// Extract block metadata from RegistrationHandle // Extract block metadata from RegistrationHandle
let block_hash = handle.sequence_hash().to_string(); let block_hash = handle.published_sequence_hash().to_string();
let parent_hash = handle.parent_sequence_hash().map(|h| h.to_string()); let parent_hash = handle
.published_parent_sequence_hash()
.map(|h| h.to_string());
// Extract block_size and tokens from RegistrationHandle // Extract block_size and tokens from RegistrationHandle
let block_size = handle.block_size(); // usize let block_size = handle.block_size(); // usize
let tokens: Vec<u32> = handle.tokens().iter().copied().collect(); let tokens: Vec<u32> = handle.tokens().iter().copied().collect();
tracing::debug!( tracing::debug!(
"DynamoEventManager sending store event to kv event consolidator: block_hash={}, block_size={}, tokens={}", "DynamoEventManager sending store event to kv event consolidator: block_hash={}, block_size={}, tokens={}, tier={:?}",
block_hash, block_hash,
block_size, block_size,
tokens.len() tokens.len(),
handle.storage_tier()
); );
// Send to consolidator with EventSource::Kvbm // Send to consolidator with EventSource::Kvbm
...@@ -224,7 +227,7 @@ impl DynamoEventManager { ...@@ -224,7 +227,7 @@ impl DynamoEventManager {
parent_hash, parent_hash,
block_size, block_size,
None, // lora_name None, // lora_name
None, // tier Some(handle.storage_tier()),
None, // data_parallel_rank None, // data_parallel_rank
) )
.await; .await;
...@@ -242,11 +245,13 @@ impl DynamoEventManager { ...@@ -242,11 +245,13 @@ impl DynamoEventManager {
/// ///
/// Called when a RegistrationHandle is dropped (block evicted from KVBM). /// Called when a RegistrationHandle is dropped (block evicted from KVBM).
fn publish_remove_event(&self, registration_handle: &RegistrationHandle) { fn publish_remove_event(&self, registration_handle: &RegistrationHandle) {
let block_hash = registration_handle.sequence_hash().to_string(); let block_hash = registration_handle.published_sequence_hash().to_string();
let tier = registration_handle.storage_tier();
tracing::debug!( tracing::debug!(
"DynamoEventManager::publish_remove_event called: block_hash={}", %block_hash,
block_hash ?tier,
"DynamoEventManager sending remove event to kv event consolidator"
); );
let kv_event_consolidator = self.consolidator_handle.clone(); let kv_event_consolidator = self.consolidator_handle.clone();
...@@ -254,7 +259,7 @@ impl DynamoEventManager { ...@@ -254,7 +259,7 @@ impl DynamoEventManager {
if let Ok(rt) = tokio::runtime::Handle::try_current() { if let Ok(rt) = tokio::runtime::Handle::try_current() {
rt.spawn(async move { rt.spawn(async move {
kv_event_consolidator kv_event_consolidator
.handle_remove(&block_hash, EventSource::Kvbm) .handle_remove(&block_hash, EventSource::Kvbm, Some(tier))
.await; .await;
}); });
} else { } else {
...@@ -288,14 +293,15 @@ impl EventReleaseManager for DynamoEventManager { ...@@ -288,14 +293,15 @@ impl EventReleaseManager for DynamoEventManager {
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use crate::block_manager::kv_consolidator::StorageTier;
use crate::tokens::SequenceHash; use crate::tokens::SequenceHash;
use super::*; use super::*;
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum EventType { pub enum EventType {
Register(SequenceHash), Register(SequenceHash, StorageTier),
Remove(SequenceHash), Remove(SequenceHash, StorageTier),
} }
pub struct MockEventManager { pub struct MockEventManager {
...@@ -322,7 +328,7 @@ pub mod tests { ...@@ -322,7 +328,7 @@ pub mod tests {
fn publish(&self, handles: Vec<Arc<RegistrationHandle>>) { fn publish(&self, handles: Vec<Arc<RegistrationHandle>>) {
let events = handles let events = handles
.iter() .iter()
.map(|handle| EventType::Register(handle.sequence_hash())) .map(|handle| EventType::Register(handle.sequence_hash(), handle.storage_tier()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.tx.send(events).unwrap(); self.tx.send(events).unwrap();
} }
...@@ -330,7 +336,10 @@ pub mod tests { ...@@ -330,7 +336,10 @@ pub mod tests {
impl EventReleaseManager for MockEventManager { impl EventReleaseManager for MockEventManager {
fn block_release(&self, registration_handle: &RegistrationHandle) { fn block_release(&self, registration_handle: &RegistrationHandle) {
let events = vec![EventType::Remove(registration_handle.sequence_hash())]; let events = vec![EventType::Remove(
registration_handle.sequence_hash(),
registration_handle.storage_tier(),
)];
self.tx.send(events).unwrap(); self.tx.send(events).unwrap();
} }
} }
......
...@@ -7,6 +7,35 @@ use serde::{Deserialize, Serialize}; ...@@ -7,6 +7,35 @@ use serde::{Deserialize, Serialize};
use super::tracker::EventSource; use super::tracker::EventSource;
/// Selects how the KV event consolidator processes store/remove events.
///
/// Variant names are (de)serialized in `snake_case`, i.e. `"dedup"` and
/// `"passthrough"`. [`KvEventConsolidationMode::Dedup`] is the `Default` value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KvEventConsolidationMode {
/// Default mode.
/// NOTE(review): presumably de-duplicates store/remove events that arrive
/// from more than one source before they are published — confirm against
/// the consolidator implementation.
#[default]
Dedup,
/// NOTE(review): presumably forwards store/remove events without
/// consolidating them — confirm against the consolidator implementation.
Passthrough,
}
impl KvEventConsolidationMode {
pub fn as_str(self) -> &'static str {
match self {
Self::Dedup => "dedup",
Self::Passthrough => "passthrough",
}
}
}
impl std::str::FromStr for KvEventConsolidationMode {
    type Err = String;

    /// Parses a consolidation mode from its name.
    ///
    /// Leading/trailing whitespace is ignored and the comparison is
    /// ASCII-case-insensitive, so `"dedup"`, `" Dedup "` and `"DEDUP"` all
    /// yield [`KvEventConsolidationMode::Dedup`].
    ///
    /// # Errors
    ///
    /// Returns a message containing the original, untrimmed input when it
    /// names neither `dedup` nor `passthrough`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let name = s.trim();

        if name.eq_ignore_ascii_case("dedup") {
            Ok(Self::Dedup)
        } else if name.eq_ignore_ascii_case("passthrough") {
            Ok(Self::Passthrough)
        } else {
            // Report the caller's input verbatim (not the trimmed form).
            Err(format!("Unknown KV event consolidator mode: {s}"))
        }
    }
}
/// Configuration for the KV Event Consolidator /// Configuration for the KV Event Consolidator
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvEventConsolidatorConfig { pub struct KvEventConsolidatorConfig {
...@@ -19,6 +48,9 @@ pub struct KvEventConsolidatorConfig { ...@@ -19,6 +48,9 @@ pub struct KvEventConsolidatorConfig {
/// Engine source for events (vLLM or TensorRT-LLM) /// Engine source for events (vLLM or TensorRT-LLM)
pub engine_source: EventSource, pub engine_source: EventSource,
/// How the consolidator should process store/remove events.
pub mode: KvEventConsolidationMode,
} }
impl Default for KvEventConsolidatorConfig { impl Default for KvEventConsolidatorConfig {
...@@ -27,6 +59,7 @@ impl Default for KvEventConsolidatorConfig { ...@@ -27,6 +59,7 @@ impl Default for KvEventConsolidatorConfig {
engine_event_endpoint: "tcp://localhost:5557".to_string(), engine_event_endpoint: "tcp://localhost:5557".to_string(),
consolidated_event_endpoint: "tcp://*:5558".to_string(), consolidated_event_endpoint: "tcp://*:5558".to_string(),
engine_source: EventSource::Vllm, engine_source: EventSource::Vllm,
mode: KvEventConsolidationMode::Dedup,
} }
} }
} }
...@@ -36,29 +69,41 @@ impl KvEventConsolidatorConfig { ...@@ -36,29 +69,41 @@ impl KvEventConsolidatorConfig {
engine_event_endpoint: String, engine_event_endpoint: String,
consolidated_event_endpoint: String, consolidated_event_endpoint: String,
engine_source: EventSource, engine_source: EventSource,
mode: KvEventConsolidationMode,
) -> Self { ) -> Self {
Self { Self {
engine_event_endpoint, engine_event_endpoint,
consolidated_event_endpoint, consolidated_event_endpoint,
engine_source, engine_source,
mode,
} }
} }
/// Create config for vLLM /// Create config for vLLM
pub fn new_vllm(engine_event_endpoint: String, consolidated_event_endpoint: String) -> Self { pub fn new_vllm(
engine_event_endpoint: String,
consolidated_event_endpoint: String,
mode: KvEventConsolidationMode,
) -> Self {
Self { Self {
engine_event_endpoint, engine_event_endpoint,
consolidated_event_endpoint, consolidated_event_endpoint,
engine_source: EventSource::Vllm, engine_source: EventSource::Vllm,
mode,
} }
} }
/// Create config for TensorRT-LLM /// Create config for TensorRT-LLM
pub fn new_trtllm(engine_event_endpoint: String, consolidated_event_endpoint: String) -> Self { pub fn new_trtllm(
engine_event_endpoint: String,
consolidated_event_endpoint: String,
mode: KvEventConsolidationMode,
) -> Self {
Self { Self {
engine_event_endpoint, engine_event_endpoint,
consolidated_event_endpoint, consolidated_event_endpoint,
engine_source: EventSource::Trtllm, engine_source: EventSource::Trtllm,
mode,
} }
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment