Unverified Commit e3d00b89 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
......@@ -44,6 +44,7 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
}
mod kv_indexer;
mod local;
mod lower_tier;
mod metrics;
mod thread_pool;
mod traits;
......@@ -62,6 +63,7 @@ mod tests;
pub use branch_sharded::*;
pub use kv_indexer::*;
pub use local::*;
pub use lower_tier::*;
pub use metrics::*;
pub use thread_pool::*;
pub use traits::*;
......
......@@ -23,7 +23,7 @@ use std::{
use rustc_hash::{FxHashMap, FxHashSet};
use super::{EventWarningKind, PreBoundEventCounters};
use super::{EventWarningKind, MatchDetails, PreBoundEventCounters};
use crate::active_set::reconcile_active_workers;
use crate::protocols::*;
......@@ -162,12 +162,20 @@ impl RadixTree {
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
/// A `MatchDetails` representing overlap scores plus continuation state.
pub fn find_match_details(
&self,
sequence: Vec<LocalBlockHash>,
early_exit: bool,
) -> MatchDetails {
let mut details = MatchDetails::new();
let MatchDetails {
overlap_scores: scores,
last_matched_hashes,
} = &mut details;
if sequence.is_empty() {
return scores;
return details;
}
let now = Instant::now();
......@@ -184,7 +192,7 @@ impl RadixTree {
};
let Some(first_child) = first_child else {
return scores;
return details;
};
// Initialize active worker set from first child.
......@@ -208,12 +216,18 @@ impl RadixTree {
}
if active.is_empty() {
return scores;
return details;
}
let mut current_hash = first_child
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
last_matched_hashes.insert(*worker, current_hash);
}
for worker in scores.scores.keys() {
let tree_size = self
......@@ -223,7 +237,7 @@ impl RadixTree {
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
return scores;
return details;
}
let mut current = first_child;
......@@ -256,6 +270,7 @@ impl RadixTree {
if child_count != active_count {
reconcile_active_workers(&mut active, &borrow.workers, |worker| {
scores.scores.insert(worker, matched_depth);
last_matched_hashes.insert(worker, current_hash);
});
active_count = active.len();
}
......@@ -281,9 +296,17 @@ impl RadixTree {
if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32;
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
break;
}
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
current = block;
matched_depth = (idx + 1) as u32;
}
......@@ -291,6 +314,7 @@ impl RadixTree {
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, current_hash);
}
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
......@@ -305,7 +329,12 @@ impl RadixTree {
scores.tree_sizes.insert(*worker, tree_size);
}
scores
details
}
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
self.find_match_details(sequence, early_exit).overlap_scores
}
/// Apply a [`RouterEvent`] to the radix tree.
......
......@@ -9,6 +9,7 @@ use tokio::sync::oneshot;
use crate::protocols::*;
use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
/// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure.
......@@ -250,6 +251,21 @@ impl dynamo_runtime::protocols::maybe_error::MaybeError for IndexerRecordRouting
}
}
/// Rich non-wire query result for router-local device tier lookups.
#[derive(Debug, Clone, Default)]
pub struct MatchDetails {
/// Existing overlap scores used by scheduling.
pub overlap_scores: OverlapScores,
/// Last matched device sequence hash per worker, used to seed lower-tier queries.
pub last_matched_hashes: FxHashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
}
impl MatchDetails {
pub fn new() -> Self {
Self::default()
}
}
/// A request to find matches in the Radix Tree.
pub struct MatchRequest {
/// A vector of `LocalBlockHash` representing the sequence to match.
......@@ -279,6 +295,30 @@ impl MatchRequest {
}
}
/// A request to find matches while also returning continuation metadata.
pub struct MatchDetailsRequest {
/// A vector of `LocalBlockHash` representing the sequence to match.
pub sequence: Vec<LocalBlockHash>,
/// A boolean indicating whether to exit early if a single match is found.
pub early_exit: bool,
/// A channel sender to send the `MatchDetails` response.
pub resp: oneshot::Sender<MatchDetails>,
}
impl MatchDetailsRequest {
pub(super) fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: oneshot::Sender<MatchDetails>,
) -> Self {
Self {
sequence,
early_exit,
resp,
}
}
}
/// A request to dump the tree as events
pub struct DumpRequest {
/// Channel to send the dumped events
......
......@@ -51,7 +51,8 @@ pub use config::{
SharedCacheType,
};
pub use indexer::{
BranchShardedIndexer, MaybeError, SharedKvCache, SyncIndexer, ThreadPoolIndexer,
BranchShardedIndexer, LowerTierContinuation, LowerTierIndexer, MaybeError, SharedKvCache,
SyncIndexer, ThreadPoolIndexer,
};
pub use nested_map::PositionalIndexer;
pub use protocols::{
......
......@@ -355,9 +355,12 @@ pub struct WorkerSelectionResult {
/// The total number of blocks required to prefill the request
pub required_blocks: u64,
/// The number of blocks that the selected worker may already have cached.
/// This is not a guarantee, but an estimate.
pub overlap_blocks: u32,
/// Approximate effective cache hit on the selected worker in fractional blocks.
/// Use `.round() as u32` for a block-count approximation.
pub effective_overlap_blocks: f64,
/// Approximate cached-token count derived from the weighted cache hit.
pub cached_tokens: usize,
}
/// Active load metrics for a worker, used for busy detection.
......
......@@ -35,6 +35,14 @@ pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
}
}
const fn default_host_cache_hit_weight() -> f64 {
0.75
}
const fn default_disk_cache_hit_weight() -> f64 {
0.25
}
/// Type of external shared KV cache to query during routing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
......@@ -170,6 +178,14 @@ pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
#[serde(default = "default_host_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub host_cache_hit_weight: f64,
#[serde(default = "default_disk_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub disk_cache_hit_weight: f64,
#[validate(range(min = 0.0))]
pub router_temperature: f64,
......@@ -269,6 +285,8 @@ impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 1.0,
host_cache_hit_weight: default_host_cache_hit_weight(),
disk_cache_hit_weight: default_disk_cache_hit_weight(),
router_temperature: 0.0,
use_kv_events: true,
durable_kv_events: false, // default to NATS Core (local indexer mode)
......
......@@ -14,8 +14,10 @@ use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator;
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use super::types::{
KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse, TierOverlapBlocks,
};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
......@@ -42,6 +44,18 @@ where
S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static,
{
fn worker_dp_range(workers: &HashMap<WorkerId, C>) -> HashMap<WorkerId, (u32, u32)> {
workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect()
}
#[allow(clippy::too_many_arguments)]
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
......@@ -84,15 +98,7 @@ where
continue;
}
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
let dp_range = Self::worker_dp_range(&current_workers);
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
......@@ -168,7 +174,10 @@ where
maybe_request_id: Option<String>,
isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
tier_overlap_blocks: TierOverlapBlocks,
effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
tree_sizes: HashMap<WorkerWithDpRank, usize>,
router_config_override: Option<&super::config::RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
......@@ -186,7 +195,10 @@ where
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
tier_overlap_blocks,
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes,
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens,
......@@ -258,7 +270,7 @@ where
&self,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
track_prefill_tokens: bool,
) -> Vec<PotentialLoad> {
let decay_now = Instant::now();
......@@ -267,7 +279,7 @@ where
.potential_blocks_and_tokens_with_prefill_tracking(
token_seq.as_deref(),
isl_tokens,
overlaps,
effective_cached_tokens,
track_prefill_tokens,
decay_now,
);
......@@ -306,7 +318,7 @@ mod tests {
use tokio::sync::{mpsc, watch};
use super::*;
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores};
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
......@@ -423,7 +435,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
Some("adapter-a".to_string()),
......@@ -462,7 +477,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Some(&crate::config::RouterConfigOverride {
track_prefill_tokens: Some(false),
..Default::default()
......@@ -489,6 +507,52 @@ mod tests {
cancel_token.cancel();
}
#[tokio::test]
async fn test_schedule_uses_weighted_cached_tokens_for_active_tracking() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let worker = WorkerWithDpRank::new(0, 0);
let response = scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
TierOverlapBlocks::default(),
HashMap::from([(worker, 0.75)]),
HashMap::from([(worker, 48)]),
HashMap::new(),
None,
true,
None,
0.0,
None,
None,
None,
None,
)
.await
.unwrap();
assert_eq!(response.best_worker, worker);
assert_eq!(response.cached_tokens, 48);
assert_eq!(response.effective_overlap_blocks, 0.75);
assert_eq!(
slots.active_tokens(Instant::now()).get(&worker).copied(),
Some(16),
"weighted cached tokens should reduce tracked prefill load",
);
cancel_token.cancel();
}
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new();
......@@ -507,7 +571,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -528,7 +595,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -570,7 +640,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -591,7 +664,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -647,7 +723,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -668,7 +747,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -723,7 +805,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -744,7 +829,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -797,7 +885,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
Some("adapter-a".to_string()),
......@@ -839,14 +930,10 @@ mod tests {
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default();
let cached_tokens = HashMap::new();
let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens(
Some(&token_seq),
128,
overlaps.clone(),
Instant::now(),
);
let (decode_blocks, prefill_tokens) =
slots.potential_blocks_and_tokens(Some(&token_seq), 128, cached_tokens.clone());
let mut expected: Vec<_> = decode_blocks
.keys()
.map(|worker| PotentialLoad {
......@@ -858,7 +945,7 @@ mod tests {
.collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps, true);
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, cached_tokens, true);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len());
......@@ -899,7 +986,10 @@ mod tests {
Some("req-1".to_string()),
100,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -914,7 +1004,7 @@ mod tests {
tokio::time::advance(Duration::from_secs(6)).await;
let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true);
let loads = scheduler.get_potential_loads(None, 0, HashMap::new(), true);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 40);
......@@ -927,7 +1017,7 @@ mod tests {
make_scheduler(HashMap::new(), None, false, None);
scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true);
let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), true);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].worker_id, 42);
......@@ -944,7 +1034,7 @@ mod tests {
assert_eq!(
scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true)
.get_potential_loads(None, 64, HashMap::new(), true,)
.len(),
1
);
......@@ -963,7 +1053,7 @@ mod tests {
tokio::time::timeout(Duration::from_secs(1), async {
loop {
if scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true)
.get_potential_loads(None, 64, HashMap::new(), true)
.len()
== 3
{
......@@ -995,7 +1085,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![11, 22]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -1008,7 +1101,7 @@ mod tests {
.await
.unwrap();
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), false);
let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), false);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 64);
......
......@@ -66,16 +66,34 @@ impl SchedulingPolicy for LcfsPolicy {
/// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before
/// long low-priority ones, reducing mean latency across the batch.
pub struct WsptPolicy {
pub block_size: usize,
}
pub struct WsptPolicy;
impl SchedulingPolicy for WsptPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
let weight = 1.0 + request.priority_jump.max(0.0);
let cached_tokens = request.overlap_blocks() as usize * self.block_size;
let allowed_ids = request.allowed_worker_ids.as_ref();
let cached_tokens = request.pinned_worker.map_or_else(
|| {
request
.effective_cached_tokens
.iter()
.filter(|(worker, _)| {
allowed_ids.is_none_or(|ids| ids.contains(&worker.worker_id))
})
.map(|(_, tokens)| *tokens)
.max()
.unwrap_or(0)
},
|worker| {
request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0)
},
);
let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
OrderedFloat(weight / new_tokens as f64)
}
......@@ -91,11 +109,11 @@ pub enum RouterSchedulingPolicy {
}
impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
pub fn new(kind: RouterQueuePolicy) -> Self {
match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy),
}
}
}
......@@ -124,11 +142,24 @@ mod tests {
priority_jump: f64,
overlaps: OverlapScores,
) -> SchedulingRequest {
let effective_overlap_blocks = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as f64))
.collect();
let effective_cached_tokens = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as usize * 16))
.collect();
SchedulingRequest {
maybe_request_id: None,
token_seq: None,
isl_tokens,
overlaps,
tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes: std::collections::HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......@@ -224,10 +255,10 @@ mod tests {
let early = Duration::from_secs(1);
let late = Duration::from_secs(10);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs);
assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs);
assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
}
......@@ -235,7 +266,7 @@ mod tests {
#[test]
fn wspt_shorter_request_scheduled_first() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let short = request_with(100, 0.0, OverlapScores::default());
let long = request_with(1000, 0.0, OverlapScores::default());
let t = Duration::ZERO;
......@@ -247,7 +278,7 @@ mod tests {
#[test]
fn wspt_overlap_reduces_effective_cost() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
// Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens).
let no_cache = request_with(1024, 0.0, OverlapScores::default());
let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
......@@ -262,7 +293,7 @@ mod tests {
#[test]
fn wspt_priority_promotes() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 5.0, OverlapScores::default());
let t = Duration::ZERO;
......@@ -274,7 +305,7 @@ mod tests {
#[test]
fn wspt_uses_max_overlap() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
// 4 workers with overlaps [10, 20, 50, 60]. max = 60.
// new_tokens = 1024 - 60*16 = 64
let req = request_with(
......@@ -289,7 +320,7 @@ mod tests {
#[test]
fn wspt_uses_pinned_worker_overlap_when_present() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
......@@ -300,7 +331,7 @@ mod tests {
#[test]
fn wspt_missing_pinned_overlap_uses_zero() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
......@@ -311,7 +342,7 @@ mod tests {
#[test]
fn wspt_no_overlap_falls_back_to_isl() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let req = request_with(512, 0.0, OverlapScores::default());
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 512.0);
......@@ -320,7 +351,7 @@ mod tests {
#[test]
fn wspt_full_overlap_clamps_to_one() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
// 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1)
let req = request_with(512, 0.0, overlaps_from(&[(0, 64)]));
let key = policy.enqueue_key(Duration::ZERO, &req);
......
......@@ -239,7 +239,7 @@ impl<
.potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(),
request.isl_tokens,
request.overlaps.clone(),
request.effective_cached_tokens.clone(),
request.track_prefill_tokens,
decay_now,
);
......@@ -263,7 +263,8 @@ impl<
request.respond(Ok(SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
effective_overlap_blocks: selection.effective_overlap_blocks,
cached_tokens: selection.cached_tokens,
}));
if !request.update_states {
......@@ -277,7 +278,7 @@ impl<
let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens,
selection.overlap_blocks,
selection.cached_tokens,
request.track_prefill_tokens,
);
......@@ -291,7 +292,7 @@ impl<
worker: selection.worker,
lora_name: request.lora_name.clone(),
},
decay_now,
Instant::now(),
) {
tracing::warn!("Failed to add request {request_id}: {e}");
}
......@@ -300,14 +301,14 @@ impl<
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
cached_tokens: usize,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
......@@ -408,7 +409,7 @@ mod tests {
use tokio::sync::{Barrier, watch};
use super::*;
use crate::protocols::{OverlapScores, WorkerSelectionResult, WorkerWithDpRank};
use crate::protocols::{WorkerSelectionResult, WorkerWithDpRank};
use crate::scheduling::types::KvSchedulerError;
use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
......@@ -499,7 +500,16 @@ mod tests {
Ok(WorkerSelectionResult {
worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: request.overlaps.scores.get(&worker).copied().unwrap_or(0),
effective_overlap_blocks: request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0),
cached_tokens: request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0),
})
}
}
......@@ -628,7 +638,10 @@ mod tests {
maybe_request_id: Some(request_id.to_string()),
token_seq: None,
isl_tokens,
overlaps: OverlapScores::default(),
tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......@@ -1020,7 +1033,10 @@ mod tests {
maybe_request_id: Some("filter-0".to_string()),
token_seq: None,
isl_tokens: isl,
overlaps: OverlapScores::default(),
tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......
......@@ -37,59 +37,49 @@ fn softmax_sample_with_sample(
temperature: f64,
sample: f64,
) -> (WorkerWithDpRank, f64) {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
assert!(!logits.is_empty(), "Empty logits for softmax sampling");
// Guard: at zero temperature, return a minimum-logit worker directly.
if temperature == 0.0 {
let mut logit_iter = logits.iter();
let (first_key, first_logit) = logit_iter.next().unwrap();
let mut min_logit = first_logit;
let mut min_key = first_key;
for (key, logit) in logit_iter {
if logit < min_logit {
min_logit = logit;
min_key = key;
}
}
return (*min_key, *min_logit);
let (worker, logit) = logits
.iter()
.min_by(|a, b| a.1.total_cmp(b.1))
.expect("logits non-empty");
return (*worker, *logit);
}
let entries: Vec<_> = logits
.iter()
.map(|(worker, logit)| (*worker, *logit))
.collect();
let values: Vec<_> = entries.iter().map(|(_, logit)| *logit).collect();
let entries: Vec<(WorkerWithDpRank, f64)> = logits.iter().map(|(w, l)| (*w, *l)).collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let (min_val, max_val) = entries
.iter()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), (_, v)| {
(lo.min(*v), hi.max(*v))
});
let probabilities = if min_val == max_val {
let mut probs = if min_val == max_val {
vec![1.0 / entries.len() as f64; entries.len()]
} else {
// Fused normalize -> negate -> scale -> exp, then normalize probabilities
let range = max_val - min_val;
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
probs
// Negate logits and rescale to [−1/temperature, 0] for numerical stability
// before softmax. Subtracting the max (which maps to min_val) keeps exp() inputs ≤ 0.
let scale = -1.0 / ((max_val - min_val) * temperature);
let max_scaled = min_val * scale;
entries
.iter()
.map(|(_, v)| (v * scale - max_scaled).exp())
.collect::<Vec<f64>>()
};
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() {
for (i, &prob) in probs.iter().enumerate() {
cumsum += prob;
if sample <= cumsum {
return entries[i];
}
}
// Fallback to last key (shouldn't normally reach here)
entries[entries.len() - 1]
*entries.last().unwrap()
}
/// Default implementation matching the Python _cost_function.
......@@ -99,12 +89,6 @@ pub struct DefaultWorkerSelector {
pub worker_type: &'static str,
}
#[derive(Debug, Clone, Copy)]
struct WorkerScore {
overlap_blocks: u32,
logit: f64,
}
impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self {
Self {
......@@ -113,7 +97,7 @@ impl DefaultWorkerSelector {
}
}
fn worker_score(
fn worker_logit(
&self,
request: &SchedulingRequest,
worker: WorkerWithDpRank,
......@@ -121,9 +105,16 @@ impl DefaultWorkerSelector {
overlap_weight: f64,
shared_cache_multiplier: f64,
formula_name: &'static str,
) -> WorkerScore {
) -> f64 {
let isl = request.isl_tokens;
let overlap_blocks = request.overlaps.scores.get(&worker).copied().unwrap_or(0);
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
// `shared_cache_hits::hits_beyond` expects an integer block count, so
// round the weighted overlap for this comparison only.
let device_overlap_blocks = effective_overlap_blocks.round().max(0.0) as u32;
let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 };
let prefill_token = request
.prefill_tokens
......@@ -134,7 +125,7 @@ impl DefaultWorkerSelector {
// Adjust prefill tokens by shared cache hits beyond this worker's device prefix.
let (adjusted_prefill_token, shared_beyond) =
if let Some(ref shared_hits) = request.shared_cache_hits {
let beyond = shared_hits.hits_beyond(overlap_blocks);
let beyond = shared_hits.hits_beyond(device_overlap_blocks);
let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64);
let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize;
(adjusted, beyond)
......@@ -153,7 +144,7 @@ impl DefaultWorkerSelector {
if shared_beyond > 0 {
tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} device blocks, \
"{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective device blocks, \
{shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \
= {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \
......@@ -163,7 +154,7 @@ impl DefaultWorkerSelector {
);
} else {
tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \
"{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
......@@ -171,10 +162,7 @@ impl DefaultWorkerSelector {
);
}
WorkerScore {
overlap_blocks,
logit,
}
logit
}
}
......@@ -201,7 +189,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
let overlap_weight = request
.router_config_override
......@@ -218,7 +205,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
if let Some(worker) = pinned_worker {
pinned_worker_config(workers, worker)?;
let score = self.worker_score(
let logit = self.worker_logit(
request,
worker,
block_size,
......@@ -226,11 +213,31 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier,
"Pinned formula",
);
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0);
tracing::info!(
"Selected pinned worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}",
self.worker_type,
worker.worker_id,
worker.dp_rank,
logit,
effective_overlap_blocks,
);
return Ok(WorkerSelectionResult {
worker,
required_blocks: request_blocks as u64,
overlap_blocks: score.overlap_blocks,
effective_overlap_blocks,
cached_tokens,
});
}
......@@ -241,7 +248,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(self.kv_router_config.router_temperature);
let get_score = |worker: WorkerWithDpRank| -> f64 {
self.worker_score(
self.worker_logit(
request,
worker,
block_size,
......@@ -249,7 +256,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier,
"Formula",
)
.logit
};
let worker_iter = workers
......@@ -282,7 +288,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
);
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers
.iter()
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
.map(|w| (request.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
......@@ -305,22 +311,58 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
softmax_sample(&worker_logits, temperature)
};
let best_host_pinned_overlap_blocks = request
.tier_overlap_blocks
.host_pinned
.get(&best_worker)
.copied()
.unwrap_or(0);
let best_disk_overlap_blocks = request
.tier_overlap_blocks
.disk
.get(&best_worker)
.copied()
.unwrap_or(0);
if self.worker_type == "decode" {
tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}",
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, host_pinned blocks: {}, disk blocks: {}",
self.worker_type,
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
);
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
return Ok(WorkerSelectionResult {
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
effective_overlap_blocks,
cached_tokens,
});
}
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let best_overlap = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let best_cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
let total_blocks_info = workers
.get(&best_worker.worker_id)
......@@ -328,20 +370,17 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
let tree_size = request
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
let tree_size = request.tree_sizes.get(&best_worker).copied().unwrap_or(0);
tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}, host_pinned blocks: {}, disk blocks: {}, tree size: {}{}",
self.worker_type,
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_overlap,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
tree_size,
total_blocks_info
);
......@@ -349,7 +388,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
Ok(WorkerSelectionResult {
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
effective_overlap_blocks: best_overlap,
cached_tokens: best_cached_tokens,
})
}
}
......@@ -479,7 +519,6 @@ mod tests {
/// Worker 1 has lower logit (less work), so it wins.
#[test]
fn test_shared_cache_hits_scoring() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig;
let block_size = 1u32;
......@@ -487,8 +526,8 @@ mod tests {
let worker0 = WorkerWithDpRank::from_worker_id(0);
let worker1 = WorkerWithDpRank::from_worker_id(1);
let mut overlaps = OverlapScores::new();
overlaps.scores.insert(worker0, 2);
let mut effective_overlap_blocks = HashMap::new();
effective_overlap_blocks.insert(worker0, 2.0);
// worker1 has 0 overlap (not in map)
#[allow(clippy::single_range_in_vec_init)]
......@@ -511,7 +550,10 @@ mod tests {
maybe_request_id: Some("test".into()),
token_seq: None,
isl_tokens: isl,
overlaps,
tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......@@ -540,15 +582,14 @@ mod tests {
/// Without shared cache hits, the scoring should be unchanged.
#[test]
fn test_no_shared_cache_unchanged() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig;
let block_size = 16u32;
let isl = 64usize;
let worker0 = WorkerWithDpRank::from_worker_id(0);
let mut overlaps = OverlapScores::new();
overlaps.scores.insert(worker0, 2);
let mut effective_overlap_blocks = HashMap::new();
effective_overlap_blocks.insert(worker0, 2.0);
let config = KvRouterConfig::default();
let selector = DefaultWorkerSelector::new(Some(config), "test");
......@@ -560,7 +601,10 @@ mod tests {
maybe_request_id: Some("test".into()),
token_seq: None,
isl_tokens: isl,
overlaps,
tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......
......@@ -8,9 +8,13 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::config::RouterConfigOverride;
use crate::protocols::{
DpRank, OverlapScores, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank,
};
use crate::protocols::{DpRank, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank};
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TierOverlapBlocks {
pub host_pinned: HashMap<WorkerWithDpRank, usize>,
pub disk: HashMap<WorkerWithDpRank, usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
......@@ -38,14 +42,20 @@ pub enum KvSchedulerError {
#[derive(Debug)]
pub struct SchedulingResponse {
pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32,
pub effective_overlap_blocks: f64,
pub cached_tokens: usize,
}
pub struct SchedulingRequest {
pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub tier_overlap_blocks: TierOverlapBlocks,
pub effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
pub effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
/// Per-worker tree size, used only for tie-breaking when multiple workers
/// produce the same logit at temperature=0.
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
pub decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
pub track_prefill_tokens: bool,
......@@ -85,16 +95,6 @@ impl SchedulingRequest {
})
}
/// Scheduling consumers use the exact pinned-worker overlap when present;
/// otherwise they use the best available overlap across eligible workers.
pub fn overlap_blocks(&self) -> u32 {
if let Some(worker) = self.pinned_worker {
return self.overlaps.scores.get(&worker).copied().unwrap_or(0);
}
self.overlaps.scores.values().copied().max().unwrap_or(0)
}
pub fn bypass_capacity_check(&self) -> bool {
self.pinned_worker.is_none() && self.allowed_worker_ids.is_some()
}
......
......@@ -11,7 +11,7 @@
use dynamo_tokens::SequenceHash;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use rustc_hash::{FxBuildHasher, FxHashMap};
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
......@@ -26,8 +26,7 @@ use super::request_maps::RequestIndex;
use super::single::{ActiveSequences, PromptMembershipDelta, RequestId};
use super::topology::WorkerTable;
use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, PrefillLoadHint,
WorkerWithDpRank,
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, PrefillLoadHint, WorkerWithDpRank,
};
// How often we force expire stale requests across all workers. See the comment
......@@ -90,6 +89,9 @@ pub enum SequenceError {
#[error("Request {request_id} not found")]
RequestNotFound { request_id: String },
#[error("Failed to publish replica-sync event: {0}")]
ReplicaSyncPublishFailed(String),
}
/// Bundled parameters for adding a request to the sequence tracker.
......@@ -587,8 +589,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlaps: OverlapScores,
decay_now: Instant,
cached_tokens: HashMap<WorkerWithDpRank, usize>,
) -> (
FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>,
......@@ -596,9 +597,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlaps,
cached_tokens,
true,
decay_now,
Instant::now(),
)
}
......@@ -606,22 +607,54 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlaps: OverlapScores,
cached_tokens: HashMap<WorkerWithDpRank, usize>,
track_prefill_tokens: bool,
decay_now: Instant,
) -> (
FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>,
) {
self.prompt_registry
.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
&overlaps,
track_prefill_tokens,
self.block_size,
decay_now,
)
#[cfg(feature = "bench")]
let start = tokio::time::Instant::now();
let table = self.workers.read();
#[cfg(feature = "bench")]
let num_workers = table.slots.len();
let mut potential_blocks =
FxHashMap::with_capacity_and_hasher(table.slots.len(), FxBuildHasher);
let mut potential_tokens =
FxHashMap::with_capacity_and_hasher(table.slots.len(), FxBuildHasher);
for slot in &table.slots {
let worker_cached_tokens = cached_tokens.get(&slot.worker).copied().unwrap_or(0);
let (blocks, tokens) = slot
.sequences
.read()
.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
worker_cached_tokens,
track_prefill_tokens,
decay_now,
);
potential_blocks.insert(slot.worker, blocks);
potential_tokens.insert(slot.worker, tokens);
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens)
}
/// Query all workers for their current number of active blocks.
......@@ -945,7 +978,6 @@ mod tests {
use rustc_hash::FxHashMap;
use super::super::prefill_tracker::added_prefill_tokens;
use super::*;
use crate::protocols::{
ActiveSequenceEvent, ActiveSequenceEventData, BlockHashOptions, OverlapScores,
......@@ -999,6 +1031,7 @@ mod tests {
FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>,
) {
let cached_tokens = cached_tokens_from_overlap_scores(overlaps, sequences.block_size);
let table = sequences.workers.read();
let mut potential_blocks = FxHashMap::default();
let mut potential_tokens = FxHashMap::default();
......@@ -1013,9 +1046,9 @@ mod tests {
});
let new_blocks =
token_sequence.map_or(0, |query| query.len().saturating_sub(overlap_depth));
let overlap = *overlaps.scores.get(&slot.worker).unwrap_or(&0);
let worker_cached_tokens = *cached_tokens.get(&slot.worker).unwrap_or(&0);
let added_tokens = if track_prefill_tokens {
added_prefill_tokens(sequences.block_size, isl, overlap)
seq.new_tokens(isl, worker_cached_tokens)
} else {
0
};
......@@ -1025,6 +1058,17 @@ mod tests {
(potential_blocks, potential_tokens)
}
fn cached_tokens_from_overlap_scores(
overlaps: &OverlapScores,
block_size: usize,
) -> HashMap<WorkerWithDpRank, usize> {
overlaps
.scores
.iter()
.map(|(worker, overlap_blocks)| (*worker, (*overlap_blocks as usize) * block_size))
.collect()
}
fn seq_hashes_for_tokens(tokens: &[u32], lora_name: Option<&str>) -> Vec<SequenceHash> {
seq_hashes_for_tokens_with_block_size(tokens, 4, lora_name)
}
......@@ -1153,7 +1197,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt),
16,
actual_overlaps,
cached_tokens_from_overlap_scores(&actual_overlaps, sequences.block_size),
true,
decay_now,
);
......@@ -1214,7 +1258,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&base_prompt),
8,
OverlapScores::default(),
HashMap::new(),
false,
decay_now,
);
......@@ -1278,7 +1322,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt_b),
3,
OverlapScores::default(),
cached_tokens_from_overlap_scores(&OverlapScores::default(), sequences.block_size),
false,
decay_now,
);
......@@ -1299,7 +1343,7 @@ mod tests {
let actual_after_free = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt_b),
3,
OverlapScores::default(),
cached_tokens_from_overlap_scores(&OverlapScores::default(), sequences.block_size),
false,
decay_now,
);
......@@ -1390,7 +1434,7 @@ mod tests {
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&[1, 2, 3]),
12,
OverlapScores::default(),
HashMap::new(),
false,
Instant::now(),
);
......@@ -1595,7 +1639,7 @@ mod tests {
let (_, potential_tokens) = sequences.potential_blocks_and_tokens_with_prefill_tracking(
None,
0,
OverlapScores::default(),
HashMap::new(),
false,
decay_now,
);
......
......@@ -51,6 +51,7 @@ impl PrefillLoadSnapshot {
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * block_size;
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
......
......@@ -534,6 +534,7 @@ impl PromptMembershipTrie {
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn compute_overlap_depths(
&self,
query: Option<&[SequenceHash]>,
......
......@@ -97,6 +97,7 @@ impl PromptRegistry {
}
#[expect(clippy::too_many_arguments)]
#[cfg_attr(not(test), allow(dead_code))]
fn project_loads_from_overlap(
&self,
query_len: usize,
......@@ -135,6 +136,7 @@ impl PromptRegistry {
(potential_blocks, potential_tokens)
}
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
......
......@@ -109,8 +109,6 @@ pub struct ActiveSequences {
requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker,
blocks: BlockTracker,
#[cfg(test)]
block_size: usize,
last_expiry_check_time: Instant,
}
......@@ -123,8 +121,6 @@ impl ActiveSequences {
requests: HashMap::new(),
prefill: PrefillLoadTracker::default(),
blocks: BlockTracker::default(),
#[cfg(test)]
block_size,
last_expiry_check_time: Instant::now(),
}
}
......@@ -157,14 +153,12 @@ impl ActiveSequences {
self.blocks.active_blocks()
}
#[cfg(test)]
pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
self.prefill.snapshot().active_tokens_at(decay_now)
}
/// Add a new request with optional prompt-token load accounting.
/// Returns block membership transitions plus any expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)]
pub(super) fn add_request_with_prefill_tracking(
&mut self,
request_id: RequestId,
......@@ -299,6 +293,31 @@ impl ActiveSequences {
membership_delta
}
pub fn new_tokens(&self, isl: usize, cached_tokens: usize) -> usize {
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens}, returning 0"
);
0
})
}
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
cached_tokens: usize,
decay_now: Instant,
) -> (usize, usize) {
self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
cached_tokens,
true,
decay_now,
)
}
/// Add an output block with a random hash and optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
......@@ -330,12 +349,11 @@ impl ActiveSequences {
acquire.became_present_on_worker.then_some(random_hash)
}
#[cfg(test)]
fn potential_blocks_and_tokens_with_prefill_tracking(
pub fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
cached_tokens: usize,
track_prefill_tokens: bool,
decay_now: Instant,
) -> (usize, usize) {
......@@ -346,7 +364,7 @@ impl ActiveSequences {
};
let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens {
added_prefill_tokens(self.block_size, isl, overlap) + active_tokens
self.new_tokens(isl, cached_tokens) + active_tokens
} else {
active_tokens
};
......
......@@ -5,8 +5,9 @@
//!
//! - This module is responsible for maintaining a registry of all blocks currently within a pool.
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
//! - The global registry is a mapping of sequences hashes to registration handles. If two blocks in different pools
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools.
//! - The global registry is keyed by sequence hash and storage tier. If two blocks in different pools
//! have the same sequence hash but live in different tiers, they keep distinct registration handles
//! so KVBM can emit per-tier events. The global registry is shared across all pools.
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
......@@ -27,12 +28,28 @@ use std::{
use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
use super::state::BlockState;
use crate::block_manager::kv_consolidator::StorageTier;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
use derive_getters::Getters;
use tokio::{runtime::Handle, sync::mpsc};
pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RegistrationKey {
sequence_hash: SequenceHash,
storage_tier: StorageTier,
}
impl RegistrationKey {
fn new(sequence_hash: SequenceHash, storage_tier: StorageTier) -> Self {
Self {
sequence_hash,
storage_tier,
}
}
}
pub type GlobalRegistry = Arc<Mutex<HashMap<RegistrationKey, Weak<RegistrationHandle>>>>;
#[derive(Debug, thiserror::Error)]
pub enum BlockRegistrationError {
......@@ -72,6 +89,7 @@ impl Drop for BlockHandle {
pub struct BlockRegistry {
blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>,
storage_tier: StorageTier,
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
......@@ -82,6 +100,7 @@ impl BlockRegistry {
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
async_runtime: Handle,
storage_tier: StorageTier,
) -> Self {
let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel();
......@@ -105,17 +124,19 @@ impl BlockRegistry {
}
let mut global_registry = global_registry.lock().unwrap();
let registration_key = RegistrationKey::new(sequence_hash, storage_tier);
if let Some(entry) = global_registry.get(&sequence_hash)
if let Some(entry) = global_registry.get(&registration_key)
&& entry.upgrade().is_none()
{
global_registry.remove(&sequence_hash);
global_registry.remove(&registration_key);
}
}
});
Self {
blocks,
storage_tier,
event_manager,
global_registry,
unregister_tx,
......@@ -165,9 +186,10 @@ impl BlockRegistry {
let reg_handle = 'reg_block: {
// Now, check the global registry.
let mut global_registry = self.global_registry.lock().unwrap();
let registration_key = RegistrationKey::new(sequence_hash, self.storage_tier);
// If an identical block exists in other pool, use the same registration handle.
if let Some(handle) = global_registry.get(&sequence_hash)
if let Some(handle) = global_registry.get(&registration_key)
&& let Some(handle) = handle.upgrade()
{
break 'reg_block handle;
......@@ -177,11 +199,12 @@ impl BlockRegistry {
publish_handle = Some(Self::create_publish_handle(
state.token_block(),
self.event_manager.clone(),
self.storage_tier,
));
let reg_handle = publish_handle.as_ref().unwrap().remove_handle();
// Insert the registration handle into the global registry.
global_registry.insert(sequence_hash, Arc::downgrade(&reg_handle));
global_registry.insert(registration_key, Arc::downgrade(&reg_handle));
reg_handle
};
......@@ -205,8 +228,10 @@ impl BlockRegistry {
fn create_publish_handle(
token_block: &TokenBlock,
event_manager: Arc<dyn EventManager>,
storage_tier: StorageTier,
) -> PublishHandle {
let reg_handle = RegistrationHandle::from_token_block(token_block, event_manager.clone());
let reg_handle =
RegistrationHandle::from_token_block(token_block, event_manager.clone(), storage_tier);
PublishHandle::new(reg_handle, event_manager)
}
......@@ -223,6 +248,15 @@ pub struct RegistrationHandle {
#[getter(copy)]
parent_sequence_hash: Option<SequenceHash>,
#[getter(copy)]
external_sequence_hash: Option<SequenceHash>,
#[getter(copy)]
external_parent_sequence_hash: Option<SequenceHash>,
#[getter(copy)]
storage_tier: StorageTier,
#[getter(skip)]
release_manager: Arc<dyn EventReleaseManager>,
......@@ -240,14 +274,29 @@ impl RegistrationHandle {
self.token_block.tokens()
}
/// Returns the router-facing sequence hash for this block.
pub fn published_sequence_hash(&self) -> SequenceHash {
self.external_sequence_hash.unwrap_or(self.sequence_hash)
}
/// Returns the router-facing parent sequence hash for this block.
pub fn published_parent_sequence_hash(&self) -> Option<SequenceHash> {
self.external_parent_sequence_hash
.or(self.parent_sequence_hash)
}
fn from_token_block(
token_block: &TokenBlock,
release_manager: Arc<dyn EventReleaseManager>,
storage_tier: StorageTier,
) -> Self {
Self {
block_hash: token_block.block_hash(),
sequence_hash: token_block.sequence_hash(),
parent_sequence_hash: token_block.parent_sequence_hash(),
external_sequence_hash: token_block.external_sequence_hash(),
external_parent_sequence_hash: token_block.external_parent_sequence_hash(),
storage_tier,
release_manager,
token_block: token_block.clone(),
}
......@@ -258,8 +307,13 @@ impl std::fmt::Debug for RegistrationHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?} }}",
self.sequence_hash, self.block_hash, self.parent_sequence_hash
"RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?}; external_sequence_hash: {:?}; external_parent_sequence_hash: {:?}; storage_tier: {:?} }}",
self.sequence_hash,
self.block_hash,
self.parent_sequence_hash,
self.external_sequence_hash,
self.external_parent_sequence_hash,
self.storage_tier
)
}
}
......@@ -274,7 +328,9 @@ impl Drop for RegistrationHandle {
mod tests {
use super::*;
use crate::block_manager::events::NullEventManager;
use crate::block_manager::events::tests::{EventType, MockEventManager};
use crate::block_manager::kv_consolidator::StorageTier;
use crate::tokens::{TokenBlockSequence, Tokens};
fn create_sequence() -> TokenBlockSequence {
......@@ -303,8 +359,11 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new();
let publish_handle =
BlockRegistry::create_publish_handle(&sequence.blocks()[0], event_manager.clone());
let publish_handle = BlockRegistry::create_publish_handle(
&sequence.blocks()[0],
event_manager.clone(),
StorageTier::Device,
);
// no event should have been triggered
assert!(rx.try_recv().is_err());
......@@ -317,7 +376,7 @@ mod tests {
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Register(sequence.blocks()[0].sequence_hash())
EventType::Register(sequence.blocks()[0].sequence_hash(), StorageTier::Device)
);
// the second event should be a Remove event
......@@ -325,7 +384,7 @@ mod tests {
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Remove(sequence.blocks()[0].sequence_hash())
EventType::Remove(sequence.blocks()[0].sequence_hash(), StorageTier::Device)
);
// there should be no more events
......@@ -340,8 +399,11 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new();
let publish_handle =
BlockRegistry::create_publish_handle(block_to_test, event_manager.clone());
let publish_handle = BlockRegistry::create_publish_handle(
block_to_test,
event_manager.clone(),
StorageTier::Device,
);
// Remove the registration handle before dropping the publish handle
let reg_handle = publish_handle.remove_handle();
......@@ -359,7 +421,7 @@ mod tests {
);
assert_eq!(
register_events[0],
EventType::Register(expected_sequence_hash),
EventType::Register(expected_sequence_hash, StorageTier::Device),
"Expected Register event"
);
......@@ -370,7 +432,7 @@ mod tests {
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Remove(expected_sequence_hash),
EventType::Remove(expected_sequence_hash, StorageTier::Device),
"Only Remove event should be triggered"
);
......@@ -389,8 +451,16 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone());
let publish_handle2 = BlockRegistry::create_publish_handle(block2, event_manager.clone());
let publish_handle1 = BlockRegistry::create_publish_handle(
block1,
event_manager.clone(),
StorageTier::Device,
);
let publish_handle2 = BlockRegistry::create_publish_handle(
block2,
event_manager.clone(),
StorageTier::Device,
);
// Remove handles before adding to publisher
let reg_handle1 = publish_handle1.remove_handle();
......@@ -413,8 +483,8 @@ mod tests {
"Should receive two Register events in one batch"
);
// Order isn't guaranteed, so check for both
assert!(events.contains(&EventType::Register(hash1)));
assert!(events.contains(&EventType::Register(hash2)));
assert!(events.contains(&EventType::Register(hash1, StorageTier::Device)));
assert!(events.contains(&EventType::Register(hash2, StorageTier::Device)));
// no more events immediately after publish
assert!(rx.try_recv().is_err());
......@@ -423,12 +493,12 @@ mod tests {
drop(reg_handle1);
let events1 = rx.try_recv().unwrap();
assert_eq!(events1.len(), 1);
assert_eq!(events1[0], EventType::Remove(hash1));
assert_eq!(events1[0], EventType::Remove(hash1, StorageTier::Device));
drop(reg_handle2);
let events2 = rx.try_recv().unwrap();
assert_eq!(events2.len(), 1);
assert_eq!(events2[0], EventType::Remove(hash2));
assert_eq!(events2[0], EventType::Remove(hash2, StorageTier::Device));
// no more events
assert!(rx.try_recv().is_err());
......@@ -453,7 +523,11 @@ mod tests {
let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone());
let publish_handle1 = BlockRegistry::create_publish_handle(
block1,
event_manager.clone(),
StorageTier::Device,
);
publisher.take_handle(publish_handle1);
......@@ -461,7 +535,7 @@ mod tests {
publisher.publish();
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0], EventType::Register(hash1));
assert_eq!(events[0], EventType::Register(hash1, StorageTier::Device));
// The RegistrationHandle Arc was taken by the publisher and dropped after the publish call
// So, the Remove event should follow immediately.
......@@ -473,7 +547,7 @@ mod tests {
);
assert_eq!(
remove_events[0],
EventType::Remove(hash1),
EventType::Remove(hash1, StorageTier::Device),
"Expected Remove event"
);
......@@ -485,4 +559,89 @@ mod tests {
drop(publisher);
assert!(rx.try_recv().is_err());
}
#[tokio::test(flavor = "current_thread")]
async fn test_same_sequence_in_different_tiers_emits_distinct_events() {
let sequence = create_sequence();
let block = sequence.blocks()[0].clone();
let sequence_hash = block.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let global_registry = GlobalRegistry::default();
let mut host_registry = BlockRegistry::new(
event_manager.clone(),
global_registry.clone(),
Handle::current(),
StorageTier::HostPinned,
);
let mut disk_registry = BlockRegistry::new(
event_manager.clone(),
global_registry,
Handle::current(),
StorageTier::Disk,
);
let mut host_state = BlockState::Reset;
host_state.apply_token_block(block.clone()).unwrap();
let host_publish = host_registry
.register_block(&mut host_state)
.unwrap()
.unwrap();
drop(host_publish);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Register(sequence_hash, StorageTier::HostPinned)]
);
let mut disk_state = BlockState::Reset;
disk_state.apply_token_block(block).unwrap();
let disk_publish = disk_registry
.register_block(&mut disk_state)
.unwrap()
.unwrap();
drop(disk_publish);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Register(sequence_hash, StorageTier::Disk)]
);
drop(host_state);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Remove(sequence_hash, StorageTier::HostPinned)]
);
drop(disk_state);
assert_eq!(
rx.recv().await.unwrap(),
vec![EventType::Remove(sequence_hash, StorageTier::Disk)]
);
}
#[test]
fn test_registration_handle_prefers_external_hashes_for_publication() {
let mut sequence = create_sequence();
sequence.sync_external_sequence_hashes(&[50_001, 50_002]);
let release_manager = NullEventManager::new();
let registration_handle = RegistrationHandle::from_token_block(
&sequence.blocks()[1],
release_manager,
StorageTier::HostPinned,
);
assert_eq!(registration_handle.external_sequence_hash(), Some(50_002));
assert_eq!(
registration_handle.external_parent_sequence_hash(),
Some(50_001)
);
assert_eq!(registration_handle.published_sequence_hash(), 50_002);
assert_eq!(
registration_handle.published_parent_sequence_hash(),
Some(50_001)
);
}
}
......@@ -230,6 +230,7 @@ impl KvBlockManagerConfigBuilder {
engine_endpoint: String,
output_endpoint: Option<String>,
engine_source: crate::block_manager::kv_consolidator::EventSource,
mode: crate::block_manager::kv_consolidator::KvEventConsolidationMode,
) -> Self {
let config = match engine_source {
crate::block_manager::kv_consolidator::EventSource::Vllm => {
......@@ -237,6 +238,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_vllm(
engine_endpoint,
output_ep,
mode,
)
}
crate::block_manager::kv_consolidator::EventSource::Trtllm => {
......@@ -248,6 +250,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_trtllm(
engine_endpoint,
output_ep,
mode,
)
}
crate::block_manager::kv_consolidator::EventSource::Kvbm => {
......
......@@ -201,18 +201,21 @@ impl DynamoEventManager {
rt.spawn(async move {
for handle in handles {
// Extract block metadata from RegistrationHandle
let block_hash = handle.sequence_hash().to_string();
let parent_hash = handle.parent_sequence_hash().map(|h| h.to_string());
let block_hash = handle.published_sequence_hash().to_string();
let parent_hash = handle
.published_parent_sequence_hash()
.map(|h| h.to_string());
// Extract block_size and tokens from RegistrationHandle
let block_size = handle.block_size(); // usize
let tokens: Vec<u32> = handle.tokens().iter().copied().collect();
tracing::debug!(
"DynamoEventManager sending store event to kv event consolidator: block_hash={}, block_size={}, tokens={}",
"DynamoEventManager sending store event to kv event consolidator: block_hash={}, block_size={}, tokens={}, tier={:?}",
block_hash,
block_size,
tokens.len()
tokens.len(),
handle.storage_tier()
);
// Send to consolidator with EventSource::Kvbm
......@@ -224,7 +227,7 @@ impl DynamoEventManager {
parent_hash,
block_size,
None, // lora_name
None, // tier
Some(handle.storage_tier()),
None, // data_parallel_rank
)
.await;
......@@ -242,11 +245,13 @@ impl DynamoEventManager {
///
/// Called when a RegistrationHandle is dropped (block evicted from KVBM).
fn publish_remove_event(&self, registration_handle: &RegistrationHandle) {
let block_hash = registration_handle.sequence_hash().to_string();
let block_hash = registration_handle.published_sequence_hash().to_string();
let tier = registration_handle.storage_tier();
tracing::debug!(
"DynamoEventManager::publish_remove_event called: block_hash={}",
block_hash
%block_hash,
?tier,
"DynamoEventManager sending remove event to kv event consolidator"
);
let kv_event_consolidator = self.consolidator_handle.clone();
......@@ -254,7 +259,7 @@ impl DynamoEventManager {
if let Ok(rt) = tokio::runtime::Handle::try_current() {
rt.spawn(async move {
kv_event_consolidator
.handle_remove(&block_hash, EventSource::Kvbm)
.handle_remove(&block_hash, EventSource::Kvbm, Some(tier))
.await;
});
} else {
......@@ -288,14 +293,15 @@ impl EventReleaseManager for DynamoEventManager {
#[cfg(test)]
pub mod tests {
use crate::block_manager::kv_consolidator::StorageTier;
use crate::tokens::SequenceHash;
use super::*;
#[derive(Debug, PartialEq, Eq)]
pub enum EventType {
Register(SequenceHash),
Remove(SequenceHash),
Register(SequenceHash, StorageTier),
Remove(SequenceHash, StorageTier),
}
pub struct MockEventManager {
......@@ -322,7 +328,7 @@ pub mod tests {
fn publish(&self, handles: Vec<Arc<RegistrationHandle>>) {
let events = handles
.iter()
.map(|handle| EventType::Register(handle.sequence_hash()))
.map(|handle| EventType::Register(handle.sequence_hash(), handle.storage_tier()))
.collect::<Vec<_>>();
self.tx.send(events).unwrap();
}
......@@ -330,7 +336,10 @@ pub mod tests {
impl EventReleaseManager for MockEventManager {
fn block_release(&self, registration_handle: &RegistrationHandle) {
let events = vec![EventType::Remove(registration_handle.sequence_hash())];
let events = vec![EventType::Remove(
registration_handle.sequence_hash(),
registration_handle.storage_tier(),
)];
self.tx.send(events).unwrap();
}
}
......
......@@ -7,6 +7,35 @@ use serde::{Deserialize, Serialize};
use super::tracker::EventSource;
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KvEventConsolidationMode {
#[default]
Dedup,
Passthrough,
}
impl KvEventConsolidationMode {
pub fn as_str(self) -> &'static str {
match self {
Self::Dedup => "dedup",
Self::Passthrough => "passthrough",
}
}
}
impl std::str::FromStr for KvEventConsolidationMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.trim().to_ascii_lowercase().as_str() {
"dedup" => Ok(Self::Dedup),
"passthrough" => Ok(Self::Passthrough),
_ => Err(format!("Unknown KV event consolidator mode: {s}")),
}
}
}
/// Configuration for the KV Event Consolidator
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvEventConsolidatorConfig {
......@@ -19,6 +48,9 @@ pub struct KvEventConsolidatorConfig {
/// Engine source for events (vLLM or TensorRT-LLM)
pub engine_source: EventSource,
/// How the consolidator should process store/remove events.
pub mode: KvEventConsolidationMode,
}
impl Default for KvEventConsolidatorConfig {
......@@ -27,6 +59,7 @@ impl Default for KvEventConsolidatorConfig {
engine_event_endpoint: "tcp://localhost:5557".to_string(),
consolidated_event_endpoint: "tcp://*:5558".to_string(),
engine_source: EventSource::Vllm,
mode: KvEventConsolidationMode::Dedup,
}
}
}
......@@ -36,29 +69,41 @@ impl KvEventConsolidatorConfig {
engine_event_endpoint: String,
consolidated_event_endpoint: String,
engine_source: EventSource,
mode: KvEventConsolidationMode,
) -> Self {
Self {
engine_event_endpoint,
consolidated_event_endpoint,
engine_source,
mode,
}
}
/// Create config for vLLM
pub fn new_vllm(engine_event_endpoint: String, consolidated_event_endpoint: String) -> Self {
pub fn new_vllm(
engine_event_endpoint: String,
consolidated_event_endpoint: String,
mode: KvEventConsolidationMode,
) -> Self {
Self {
engine_event_endpoint,
consolidated_event_endpoint,
engine_source: EventSource::Vllm,
mode,
}
}
/// Create config for TensorRT-LLM
pub fn new_trtllm(engine_event_endpoint: String, consolidated_event_endpoint: String) -> Self {
pub fn new_trtllm(
engine_event_endpoint: String,
consolidated_event_endpoint: String,
mode: KvEventConsolidationMode,
) -> Self {
Self {
engine_event_endpoint,
consolidated_event_endpoint,
engine_source: EventSource::Trtllm,
mode,
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment