Unverified commit e3d00b89, authored by jthomson04 and committed by GitHub
Browse files

feat: Tier-based KV Routing (#8380)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 7e48f3bd
......@@ -44,6 +44,7 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
}
mod kv_indexer;
mod local;
mod lower_tier;
mod metrics;
mod thread_pool;
mod traits;
......@@ -62,6 +63,7 @@ mod tests;
pub use branch_sharded::*;
pub use kv_indexer::*;
pub use local::*;
pub use lower_tier::*;
pub use metrics::*;
pub use thread_pool::*;
pub use traits::*;
......
......@@ -23,7 +23,7 @@ use std::{
use rustc_hash::{FxHashMap, FxHashSet};
use super::{EventWarningKind, PreBoundEventCounters};
use super::{EventWarningKind, MatchDetails, PreBoundEventCounters};
use crate::active_set::reconcile_active_workers;
use crate::protocols::*;
......@@ -162,12 +162,20 @@ impl RadixTree {
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
/// A `MatchDetails` representing overlap scores plus continuation state.
pub fn find_match_details(
&self,
sequence: Vec<LocalBlockHash>,
early_exit: bool,
) -> MatchDetails {
let mut details = MatchDetails::new();
let MatchDetails {
overlap_scores: scores,
last_matched_hashes,
} = &mut details;
if sequence.is_empty() {
return scores;
return details;
}
let now = Instant::now();
......@@ -184,7 +192,7 @@ impl RadixTree {
};
let Some(first_child) = first_child else {
return scores;
return details;
};
// Initialize active worker set from first child.
......@@ -208,12 +216,18 @@ impl RadixTree {
}
if active.is_empty() {
return scores;
return details;
}
let mut current_hash = first_child
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
if early_exit && active_count == 1 {
for worker in &active {
scores.scores.insert(*worker, 1);
last_matched_hashes.insert(*worker, current_hash);
}
for worker in scores.scores.keys() {
let tree_size = self
......@@ -223,7 +237,7 @@ impl RadixTree {
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
return scores;
return details;
}
let mut current = first_child;
......@@ -256,6 +270,7 @@ impl RadixTree {
if child_count != active_count {
reconcile_active_workers(&mut active, &borrow.workers, |worker| {
scores.scores.insert(worker, matched_depth);
last_matched_hashes.insert(worker, current_hash);
});
active_count = active.len();
}
......@@ -281,9 +296,17 @@ impl RadixTree {
if early_exit && active_count == 1 {
matched_depth = (idx + 1) as u32;
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
break;
}
current_hash = block
.borrow()
.block_hash
.expect("matched radix node must have a block hash");
current = block;
matched_depth = (idx + 1) as u32;
}
......@@ -291,6 +314,7 @@ impl RadixTree {
// Record scores for workers that survived through the deepest matched level.
for worker in &active {
scores.scores.insert(*worker, matched_depth);
last_matched_hashes.insert(*worker, current_hash);
}
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
......@@ -305,7 +329,12 @@ impl RadixTree {
scores.tree_sizes.insert(*worker, tree_size);
}
scores
details
}
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
self.find_match_details(sequence, early_exit).overlap_scores
}
/// Apply a [`RouterEvent`] to the radix tree.
......
......@@ -9,6 +9,7 @@ use tokio::sync::oneshot;
use crate::protocols::*;
use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
/// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure.
......@@ -250,6 +251,21 @@ impl dynamo_runtime::protocols::maybe_error::MaybeError for IndexerRecordRouting
}
}
/// Rich non-wire query result for router-local device tier lookups.
#[derive(Debug, Clone, Default)]
pub struct MatchDetails {
/// Existing overlap scores used by scheduling.
pub overlap_scores: OverlapScores,
/// Last matched device sequence hash per worker, used to seed lower-tier queries.
pub last_matched_hashes: FxHashMap<WorkerWithDpRank, ExternalSequenceBlockHash>,
}
impl MatchDetails {
pub fn new() -> Self {
Self::default()
}
}
/// A request to find matches in the Radix Tree.
pub struct MatchRequest {
/// A vector of `LocalBlockHash` representing the sequence to match.
......@@ -279,6 +295,30 @@ impl MatchRequest {
}
}
/// Request for a match query whose reply is a full `MatchDetails`
/// (overlap scores together with continuation metadata).
pub struct MatchDetailsRequest {
    /// The `LocalBlockHash` sequence to match.
    pub sequence: Vec<LocalBlockHash>,
    /// Whether to exit early if a single match is found.
    pub early_exit: bool,
    /// One-shot sender on which the `MatchDetails` response is delivered.
    pub resp: oneshot::Sender<MatchDetails>,
}

impl MatchDetailsRequest {
    /// Packs the query inputs and the response channel into a request.
    pub(super) fn new(
        sequence: Vec<LocalBlockHash>,
        early_exit: bool,
        resp: oneshot::Sender<MatchDetails>,
    ) -> Self {
        MatchDetailsRequest {
            sequence,
            early_exit,
            resp,
        }
    }
}
/// A request to dump the tree as events
pub struct DumpRequest {
/// Channel to send the dumped events
......
......@@ -51,7 +51,8 @@ pub use config::{
SharedCacheType,
};
pub use indexer::{
BranchShardedIndexer, MaybeError, SharedKvCache, SyncIndexer, ThreadPoolIndexer,
BranchShardedIndexer, LowerTierContinuation, LowerTierIndexer, MaybeError, SharedKvCache,
SyncIndexer, ThreadPoolIndexer,
};
pub use nested_map::PositionalIndexer;
pub use protocols::{
......
......@@ -355,9 +355,12 @@ pub struct WorkerSelectionResult {
/// The total number of blocks required to prefill the request
pub required_blocks: u64,
/// The number of blocks that the selected worker may already have cached.
/// This is not a guarantee, but an estimate.
pub overlap_blocks: u32,
/// Approximate effective cache hit on the selected worker in fractional blocks.
/// Use `.round() as u32` for a block-count approximation.
pub effective_overlap_blocks: f64,
/// Approximate cached-token count derived from the weighted cache hit.
pub cached_tokens: usize,
}
/// Active load metrics for a worker, used for busy detection.
......
......@@ -35,6 +35,14 @@ pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
}
}
/// Default for `KvRouterConfig::host_cache_hit_weight`.
///
/// Referenced both by the field's `#[serde(default = ...)]` attribute and by
/// `KvRouterConfig::default()`, so the two stay in agreement.
const fn default_host_cache_hit_weight() -> f64 {
    const HOST_CACHE_HIT_WEIGHT: f64 = 0.75;
    HOST_CACHE_HIT_WEIGHT
}
/// Default for `KvRouterConfig::disk_cache_hit_weight`.
///
/// Referenced both by the field's `#[serde(default = ...)]` attribute and by
/// `KvRouterConfig::default()`, so the two stay in agreement.
const fn default_disk_cache_hit_weight() -> f64 {
    const DISK_CACHE_HIT_WEIGHT: f64 = 0.25;
    DISK_CACHE_HIT_WEIGHT
}
/// Type of external shared KV cache to query during routing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
......@@ -170,6 +178,14 @@ pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
#[serde(default = "default_host_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub host_cache_hit_weight: f64,
#[serde(default = "default_disk_cache_hit_weight")]
#[validate(range(min = 0.0, max = 1.0))]
pub disk_cache_hit_weight: f64,
#[validate(range(min = 0.0))]
pub router_temperature: f64,
......@@ -269,6 +285,8 @@ impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 1.0,
host_cache_hit_weight: default_host_cache_hit_weight(),
disk_cache_hit_weight: default_disk_cache_hit_weight(),
router_temperature: 0.0,
use_kv_events: true,
durable_kv_events: false, // default to NATS Core (local indexer mode)
......
......@@ -14,8 +14,10 @@ use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator;
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use super::types::{
KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse, TierOverlapBlocks,
};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
......@@ -42,6 +44,18 @@ where
S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static,
{
/// Maps each worker id to `(data_parallel_start_rank, data_parallel_size)`
/// as reported by that worker's config.
fn worker_dp_range(workers: &HashMap<WorkerId, C>) -> HashMap<WorkerId, (u32, u32)> {
    let mut dp_range = HashMap::with_capacity(workers.len());
    for (&worker_id, config) in workers {
        let start_rank = config.data_parallel_start_rank();
        let size = config.data_parallel_size();
        dp_range.insert(worker_id, (start_rank, size));
    }
    dp_range
}
#[allow(clippy::too_many_arguments)]
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
......@@ -84,15 +98,7 @@ where
continue;
}
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
let dp_range = Self::worker_dp_range(&current_workers);
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
......@@ -168,7 +174,10 @@ where
maybe_request_id: Option<String>,
isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
tier_overlap_blocks: TierOverlapBlocks,
effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
tree_sizes: HashMap<WorkerWithDpRank, usize>,
router_config_override: Option<&super::config::RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
......@@ -186,7 +195,10 @@ where
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
tier_overlap_blocks,
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes,
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens,
......@@ -258,7 +270,7 @@ where
&self,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
track_prefill_tokens: bool,
) -> Vec<PotentialLoad> {
let decay_now = Instant::now();
......@@ -267,7 +279,7 @@ where
.potential_blocks_and_tokens_with_prefill_tracking(
token_seq.as_deref(),
isl_tokens,
overlaps,
effective_cached_tokens,
track_prefill_tokens,
decay_now,
);
......@@ -306,7 +318,7 @@ mod tests {
use tokio::sync::{mpsc, watch};
use super::*;
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores};
use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData};
use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
......@@ -423,7 +435,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
Some("adapter-a".to_string()),
......@@ -462,7 +477,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Some(&crate::config::RouterConfigOverride {
track_prefill_tokens: Some(false),
..Default::default()
......@@ -489,6 +507,52 @@ mod tests {
cancel_token.cancel();
}
/// Scheduling a 64-token request with 48 effective cached tokens on the only
/// worker must report those cached tokens in the response and leave 16 active
/// tokens tracked for that worker.
#[tokio::test]
async fn test_schedule_uses_weighted_cached_tokens_for_active_tracking() {
    // Single worker (id 0) with a 64-token batch budget.
    let single_worker = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, slots, _cfg_tx, cancel_token) =
        make_scheduler(single_worker, None, true, None);

    let target = WorkerWithDpRank::new(0, 0);
    let effective_overlap_blocks = HashMap::from([(target, 0.75)]);
    let effective_cached_tokens = HashMap::from([(target, 48)]);

    let scheduled = scheduler
        .schedule(
            Some("req-1".to_string()),    // maybe_request_id
            64,                           // isl_tokens
            Some(vec![1, 2, 3, 4]),       // token_seq
            TierOverlapBlocks::default(), // tier_overlap_blocks
            effective_overlap_blocks,
            effective_cached_tokens,
            HashMap::new(), // tree_sizes
            None,           // router_config_override
            true,           // update_states
            None,           // lora_name
            0.0,
            None,
            None,
            None,
            None,
        )
        .await
        .unwrap();

    // The response echoes the weighted values supplied for the chosen worker.
    assert_eq!(scheduled.best_worker, target);
    assert_eq!(scheduled.cached_tokens, 48);
    assert_eq!(scheduled.effective_overlap_blocks, 0.75);

    // 64 ISL tokens - 48 cached tokens = 16 tokens still tracked as active.
    let tracked = slots.active_tokens(Instant::now()).get(&target).copied();
    assert_eq!(
        tracked,
        Some(16),
        "weighted cached tokens should reduce tracked prefill load",
    );

    cancel_token.cancel();
}
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new();
......@@ -507,7 +571,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -528,7 +595,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -570,7 +640,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -591,7 +664,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -647,7 +723,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -668,7 +747,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -723,7 +805,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -744,7 +829,10 @@ mod tests {
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -797,7 +885,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
Some("adapter-a".to_string()),
......@@ -839,14 +930,10 @@ mod tests {
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default();
let cached_tokens = HashMap::new();
let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens(
Some(&token_seq),
128,
overlaps.clone(),
Instant::now(),
);
let (decode_blocks, prefill_tokens) =
slots.potential_blocks_and_tokens(Some(&token_seq), 128, cached_tokens.clone());
let mut expected: Vec<_> = decode_blocks
.keys()
.map(|worker| PotentialLoad {
......@@ -858,7 +945,7 @@ mod tests {
.collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps, true);
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, cached_tokens, true);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len());
......@@ -899,7 +986,10 @@ mod tests {
Some("req-1".to_string()),
100,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -914,7 +1004,7 @@ mod tests {
tokio::time::advance(Duration::from_secs(6)).await;
let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true);
let loads = scheduler.get_potential_loads(None, 0, HashMap::new(), true);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 40);
......@@ -927,7 +1017,7 @@ mod tests {
make_scheduler(HashMap::new(), None, false, None);
scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true);
let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), true);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].worker_id, 42);
......@@ -944,7 +1034,7 @@ mod tests {
assert_eq!(
scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true)
.get_potential_loads(None, 64, HashMap::new(), true,)
.len(),
1
);
......@@ -963,7 +1053,7 @@ mod tests {
tokio::time::timeout(Duration::from_secs(1), async {
loop {
if scheduler
.get_potential_loads(None, 64, OverlapScores::default(), true)
.get_potential_loads(None, 64, HashMap::new(), true)
.len()
== 3
{
......@@ -995,7 +1085,10 @@ mod tests {
Some("req-1".to_string()),
64,
Some(vec![11, 22]),
OverlapScores::default(),
TierOverlapBlocks::default(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
None,
true,
None,
......@@ -1008,7 +1101,7 @@ mod tests {
.await
.unwrap();
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), false);
let loads = scheduler.get_potential_loads(None, 64, HashMap::new(), false);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 64);
......
......@@ -66,16 +66,34 @@ impl SchedulingPolicy for LcfsPolicy {
/// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before
/// long low-priority ones, reducing mean latency across the batch.
pub struct WsptPolicy {
pub block_size: usize,
}
pub struct WsptPolicy;
impl SchedulingPolicy for WsptPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
let weight = 1.0 + request.priority_jump.max(0.0);
let cached_tokens = request.overlap_blocks() as usize * self.block_size;
let allowed_ids = request.allowed_worker_ids.as_ref();
let cached_tokens = request.pinned_worker.map_or_else(
|| {
request
.effective_cached_tokens
.iter()
.filter(|(worker, _)| {
allowed_ids.is_none_or(|ids| ids.contains(&worker.worker_id))
})
.map(|(_, tokens)| *tokens)
.max()
.unwrap_or(0)
},
|worker| {
request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0)
},
);
let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
OrderedFloat(weight / new_tokens as f64)
}
......@@ -91,11 +109,11 @@ pub enum RouterSchedulingPolicy {
}
impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
pub fn new(kind: RouterQueuePolicy) -> Self {
match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy),
}
}
}
......@@ -124,11 +142,24 @@ mod tests {
priority_jump: f64,
overlaps: OverlapScores,
) -> SchedulingRequest {
let effective_overlap_blocks = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as f64))
.collect();
let effective_cached_tokens = overlaps
.scores
.iter()
.map(|(worker, overlap)| (*worker, *overlap as usize * 16))
.collect();
SchedulingRequest {
maybe_request_id: None,
token_seq: None,
isl_tokens,
overlaps,
tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens,
tree_sizes: std::collections::HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......@@ -224,10 +255,10 @@ mod tests {
let early = Duration::from_secs(1);
let late = Duration::from_secs(10);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs);
assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs);
assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
}
......@@ -235,7 +266,7 @@ mod tests {
#[test]
fn wspt_shorter_request_scheduled_first() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let short = request_with(100, 0.0, OverlapScores::default());
let long = request_with(1000, 0.0, OverlapScores::default());
let t = Duration::ZERO;
......@@ -247,7 +278,7 @@ mod tests {
#[test]
fn wspt_overlap_reduces_effective_cost() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
// Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens).
let no_cache = request_with(1024, 0.0, OverlapScores::default());
let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
......@@ -262,7 +293,7 @@ mod tests {
#[test]
fn wspt_priority_promotes() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 5.0, OverlapScores::default());
let t = Duration::ZERO;
......@@ -274,7 +305,7 @@ mod tests {
#[test]
fn wspt_uses_max_overlap() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
// 4 workers with overlaps [10, 20, 50, 60]. max = 60.
// new_tokens = 1024 - 60*16 = 64
let req = request_with(
......@@ -289,7 +320,7 @@ mod tests {
#[test]
fn wspt_uses_pinned_worker_overlap_when_present() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
......@@ -300,7 +331,7 @@ mod tests {
#[test]
fn wspt_missing_pinned_overlap_uses_zero() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
......@@ -311,7 +342,7 @@ mod tests {
#[test]
fn wspt_no_overlap_falls_back_to_isl() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
let req = request_with(512, 0.0, OverlapScores::default());
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 512.0);
......@@ -320,7 +351,7 @@ mod tests {
#[test]
fn wspt_full_overlap_clamps_to_one() {
let policy = WsptPolicy { block_size: 16 };
let policy = WsptPolicy;
// 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1)
let req = request_with(512, 0.0, overlaps_from(&[(0, 64)]));
let key = policy.enqueue_key(Duration::ZERO, &req);
......
......@@ -239,7 +239,7 @@ impl<
.potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(),
request.isl_tokens,
request.overlaps.clone(),
request.effective_cached_tokens.clone(),
request.track_prefill_tokens,
decay_now,
);
......@@ -263,7 +263,8 @@ impl<
request.respond(Ok(SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
effective_overlap_blocks: selection.effective_overlap_blocks,
cached_tokens: selection.cached_tokens,
}));
if !request.update_states {
......@@ -277,7 +278,7 @@ impl<
let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens,
selection.overlap_blocks,
selection.cached_tokens,
request.track_prefill_tokens,
);
......@@ -291,7 +292,7 @@ impl<
worker: selection.worker,
lora_name: request.lora_name.clone(),
},
decay_now,
Instant::now(),
) {
tracing::warn!("Failed to add request {request_id}: {e}");
}
......@@ -300,14 +301,14 @@ impl<
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
cached_tokens: usize,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let prefix = cached_tokens.min(isl_tokens);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
......@@ -408,7 +409,7 @@ mod tests {
use tokio::sync::{Barrier, watch};
use super::*;
use crate::protocols::{OverlapScores, WorkerSelectionResult, WorkerWithDpRank};
use crate::protocols::{WorkerSelectionResult, WorkerWithDpRank};
use crate::scheduling::types::KvSchedulerError;
use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
......@@ -499,7 +500,16 @@ mod tests {
Ok(WorkerSelectionResult {
worker,
required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
overlap_blocks: request.overlaps.scores.get(&worker).copied().unwrap_or(0),
effective_overlap_blocks: request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0),
cached_tokens: request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0),
})
}
}
......@@ -628,7 +638,10 @@ mod tests {
maybe_request_id: Some(request_id.to_string()),
token_seq: None,
isl_tokens,
overlaps: OverlapScores::default(),
tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......@@ -1020,7 +1033,10 @@ mod tests {
maybe_request_id: Some("filter-0".to_string()),
token_seq: None,
isl_tokens: isl,
overlaps: OverlapScores::default(),
tier_overlap_blocks: Default::default(),
effective_overlap_blocks: HashMap::new(),
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......
......@@ -37,59 +37,49 @@ fn softmax_sample_with_sample(
temperature: f64,
sample: f64,
) -> (WorkerWithDpRank, f64) {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
assert!(!logits.is_empty(), "Empty logits for softmax sampling");
// Guard: at zero temperature, return a minimum-logit worker directly.
if temperature == 0.0 {
let mut logit_iter = logits.iter();
let (first_key, first_logit) = logit_iter.next().unwrap();
let mut min_logit = first_logit;
let mut min_key = first_key;
for (key, logit) in logit_iter {
if logit < min_logit {
min_logit = logit;
min_key = key;
}
let (worker, logit) = logits
.iter()
.min_by(|a, b| a.1.total_cmp(b.1))
.expect("logits non-empty");
return (*worker, *logit);
}
return (*min_key, *min_logit);
}
let entries: Vec<(WorkerWithDpRank, f64)> = logits.iter().map(|(w, l)| (*w, *l)).collect();
let entries: Vec<_> = logits
let (min_val, max_val) = entries
.iter()
.map(|(worker, logit)| (*worker, *logit))
.collect();
let values: Vec<_> = entries.iter().map(|(_, logit)| *logit).collect();
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), (_, v)| {
(lo.min(*v), hi.max(*v))
});
let probabilities = if min_val == max_val {
let mut probs = if min_val == max_val {
vec![1.0 / entries.len() as f64; entries.len()]
} else {
// Fused normalize -> negate -> scale -> exp, then normalize probabilities
let range = max_val - min_val;
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
// Negate logits and rescale to [−1/temperature, 0] for numerical stability
// before softmax. Subtracting the max (which maps to min_val) keeps exp() inputs ≤ 0.
let scale = -1.0 / ((max_val - min_val) * temperature);
let max_scaled = min_val * scale;
entries
.iter()
.map(|(_, v)| (v * scale - max_scaled).exp())
.collect::<Vec<f64>>()
};
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
probs
};
let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() {
for (i, &prob) in probs.iter().enumerate() {
cumsum += prob;
if sample <= cumsum {
return entries[i];
}
}
// Fallback to last key (shouldn't normally reach here)
entries[entries.len() - 1]
*entries.last().unwrap()
}
/// Default implementation matching the Python _cost_function.
......@@ -99,12 +89,6 @@ pub struct DefaultWorkerSelector {
pub worker_type: &'static str,
}
#[derive(Debug, Clone, Copy)]
struct WorkerScore {
overlap_blocks: u32,
logit: f64,
}
impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self {
Self {
......@@ -113,7 +97,7 @@ impl DefaultWorkerSelector {
}
}
fn worker_score(
fn worker_logit(
&self,
request: &SchedulingRequest,
worker: WorkerWithDpRank,
......@@ -121,9 +105,16 @@ impl DefaultWorkerSelector {
overlap_weight: f64,
shared_cache_multiplier: f64,
formula_name: &'static str,
) -> WorkerScore {
) -> f64 {
let isl = request.isl_tokens;
let overlap_blocks = request.overlaps.scores.get(&worker).copied().unwrap_or(0);
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
// `shared_cache_hits::hits_beyond` expects an integer block count, so
// round the weighted overlap for this comparison only.
let device_overlap_blocks = effective_overlap_blocks.round().max(0.0) as u32;
let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 };
let prefill_token = request
.prefill_tokens
......@@ -134,7 +125,7 @@ impl DefaultWorkerSelector {
// Adjust prefill tokens by shared cache hits beyond this worker's device prefix.
let (adjusted_prefill_token, shared_beyond) =
if let Some(ref shared_hits) = request.shared_cache_hits {
let beyond = shared_hits.hits_beyond(overlap_blocks);
let beyond = shared_hits.hits_beyond(device_overlap_blocks);
let reduction = shared_cache_multiplier * (beyond as f64) * (block_size as f64);
let adjusted = (prefill_token as f64 - reduction).max(0.0) as usize;
(adjusted, beyond)
......@@ -153,7 +144,7 @@ impl DefaultWorkerSelector {
if shared_beyond > 0 {
tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} device blocks, \
"{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective device blocks, \
{shared_beyond} shared blocks beyond device (multiplier={shared_cache_multiplier:.2}): {logit:.3} \
= {overlap_weight:.1} * adjusted_prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3} \
......@@ -163,7 +154,7 @@ impl DefaultWorkerSelector {
);
} else {
tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \
"{formula_name} for worker_id={} dp_rank={:?} with {effective_overlap_blocks:.2} effective cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
......@@ -171,10 +162,7 @@ impl DefaultWorkerSelector {
);
}
WorkerScore {
overlap_blocks,
logit,
}
logit
}
}
......@@ -201,7 +189,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
let overlap_weight = request
.router_config_override
......@@ -218,7 +205,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
if let Some(worker) = pinned_worker {
pinned_worker_config(workers, worker)?;
let score = self.worker_score(
let logit = self.worker_logit(
request,
worker,
block_size,
......@@ -226,11 +213,31 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier,
"Pinned formula",
);
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&worker)
.copied()
.unwrap_or(0);
tracing::info!(
"Selected pinned worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}",
self.worker_type,
worker.worker_id,
worker.dp_rank,
logit,
effective_overlap_blocks,
);
return Ok(WorkerSelectionResult {
worker,
required_blocks: request_blocks as u64,
overlap_blocks: score.overlap_blocks,
effective_overlap_blocks,
cached_tokens,
});
}
......@@ -241,7 +248,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(self.kv_router_config.router_temperature);
let get_score = |worker: WorkerWithDpRank| -> f64 {
self.worker_score(
self.worker_logit(
request,
worker,
block_size,
......@@ -249,7 +256,6 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
shared_cache_multiplier,
"Formula",
)
.logit
};
let worker_iter = workers
......@@ -282,7 +288,7 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
);
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = min_workers
.iter()
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
.map(|w| (request.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
......@@ -305,22 +311,58 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
softmax_sample(&worker_logits, temperature)
};
let best_host_pinned_overlap_blocks = request
.tier_overlap_blocks
.host_pinned
.get(&best_worker)
.copied()
.unwrap_or(0);
let best_disk_overlap_blocks = request
.tier_overlap_blocks
.disk
.get(&best_worker)
.copied()
.unwrap_or(0);
if self.worker_type == "decode" {
tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}",
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, host_pinned blocks: {}, disk blocks: {}",
self.worker_type,
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
);
let effective_overlap_blocks = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
return Ok(WorkerSelectionResult {
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
effective_overlap_blocks,
cached_tokens,
});
}
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
let best_overlap = request
.effective_overlap_blocks
.get(&best_worker)
.copied()
.unwrap_or(0.0);
let best_cached_tokens = request
.effective_cached_tokens
.get(&best_worker)
.copied()
.unwrap_or(0);
let total_blocks_info = workers
.get(&best_worker.worker_id)
......@@ -328,20 +370,17 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
let tree_size = request
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
let tree_size = request.tree_sizes.get(&best_worker).copied().unwrap_or(0);
tracing::info!(
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
"Selected worker: worker_type={}, worker_id={} dp_rank={:?}, logit: {:.3}, effective cached blocks: {:.2}, host_pinned blocks: {}, disk blocks: {}, tree size: {}{}",
self.worker_type,
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_overlap,
best_host_pinned_overlap_blocks,
best_disk_overlap_blocks,
tree_size,
total_blocks_info
);
......@@ -349,7 +388,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
Ok(WorkerSelectionResult {
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
effective_overlap_blocks: best_overlap,
cached_tokens: best_cached_tokens,
})
}
}
......@@ -479,7 +519,6 @@ mod tests {
/// Worker 1 has lower logit (less work), so it wins.
#[test]
fn test_shared_cache_hits_scoring() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig;
let block_size = 1u32;
......@@ -487,8 +526,8 @@ mod tests {
let worker0 = WorkerWithDpRank::from_worker_id(0);
let worker1 = WorkerWithDpRank::from_worker_id(1);
let mut overlaps = OverlapScores::new();
overlaps.scores.insert(worker0, 2);
let mut effective_overlap_blocks = HashMap::new();
effective_overlap_blocks.insert(worker0, 2.0);
// worker1 has 0 overlap (not in map)
#[allow(clippy::single_range_in_vec_init)]
......@@ -511,7 +550,10 @@ mod tests {
maybe_request_id: Some("test".into()),
token_seq: None,
isl_tokens: isl,
overlaps,
tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......@@ -540,15 +582,14 @@ mod tests {
/// Without shared cache hits, the scoring should be unchanged.
#[test]
fn test_no_shared_cache_unchanged() {
use crate::protocols::OverlapScores;
use crate::test_utils::SimpleWorkerConfig;
let block_size = 16u32;
let isl = 64usize;
let worker0 = WorkerWithDpRank::from_worker_id(0);
let mut overlaps = OverlapScores::new();
overlaps.scores.insert(worker0, 2);
let mut effective_overlap_blocks = HashMap::new();
effective_overlap_blocks.insert(worker0, 2.0);
let config = KvRouterConfig::default();
let selector = DefaultWorkerSelector::new(Some(config), "test");
......@@ -560,7 +601,10 @@ mod tests {
maybe_request_id: Some("test".into()),
token_seq: None,
isl_tokens: isl,
overlaps,
tier_overlap_blocks: Default::default(),
effective_overlap_blocks,
effective_cached_tokens: HashMap::new(),
tree_sizes: HashMap::new(),
decode_blocks: FxHashMap::default(),
prefill_tokens: FxHashMap::default(),
track_prefill_tokens: true,
......
......@@ -8,9 +8,13 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::config::RouterConfigOverride;
use crate::protocols::{
DpRank, OverlapScores, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank,
};
use crate::protocols::{DpRank, SharedCacheHits, WorkerConfigLike, WorkerId, WorkerWithDpRank};
/// Per-worker overlap block counts, split by storage tier.
///
/// Both maps are keyed by `WorkerWithDpRank`. The worker selector reads a
/// worker that is absent from a map as having zero blocks in that tier.
// NOTE(review): presumably these count request blocks already cached in the
// named tier on that worker — confirm against the code that fills the maps.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TierOverlapBlocks {
/// Overlap block count per worker for the host-pinned tier.
pub host_pinned: HashMap<WorkerWithDpRank, usize>,
/// Overlap block count per worker for the disk tier.
pub disk: HashMap<WorkerWithDpRank, usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
......@@ -38,14 +42,20 @@ pub enum KvSchedulerError {
#[derive(Debug)]
pub struct SchedulingResponse {
pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32,
pub effective_overlap_blocks: f64,
pub cached_tokens: usize,
}
pub struct SchedulingRequest {
pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub tier_overlap_blocks: TierOverlapBlocks,
pub effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
pub effective_cached_tokens: HashMap<WorkerWithDpRank, usize>,
/// Per-worker tree size, used only for tie-breaking when multiple workers
/// produce the same logit at temperature=0.
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
pub decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
pub track_prefill_tokens: bool,
......@@ -85,16 +95,6 @@ impl SchedulingRequest {
})
}
/// Overlap block count that scheduling consumers should use for this request.
///
/// With a pinned worker, this is exactly that worker's overlap score (0 when
/// the worker has no entry). Without one, it is the largest overlap score
/// across all recorded workers (0 when there are none).
pub fn overlap_blocks(&self) -> u32 {
    let scores = &self.overlaps.scores;
    match self.pinned_worker {
        Some(worker) => scores.get(&worker).copied().unwrap_or(0),
        None => scores.values().copied().max().unwrap_or(0),
    }
}
/// Returns `true` only when no worker is pinned and an allowed-worker list
/// is present on the request.
pub fn bypass_capacity_check(&self) -> bool {
    matches!(
        (&self.pinned_worker, &self.allowed_worker_ids),
        (None, Some(_))
    )
}
......
This diff is collapsed.
......@@ -51,6 +51,7 @@ impl PrefillLoadSnapshot {
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * block_size;
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
......
......@@ -534,6 +534,7 @@ impl PromptMembershipTrie {
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn compute_overlap_depths(
&self,
query: Option<&[SequenceHash]>,
......
......@@ -97,6 +97,7 @@ impl PromptRegistry {
}
#[expect(clippy::too_many_arguments)]
#[cfg_attr(not(test), allow(dead_code))]
fn project_loads_from_overlap(
&self,
query_len: usize,
......@@ -135,6 +136,7 @@ impl PromptRegistry {
(potential_blocks, potential_tokens)
}
#[cfg_attr(not(test), allow(dead_code))]
pub(super) fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
......
......@@ -109,8 +109,6 @@ pub struct ActiveSequences {
requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker,
blocks: BlockTracker,
#[cfg(test)]
block_size: usize,
last_expiry_check_time: Instant,
}
......@@ -123,8 +121,6 @@ impl ActiveSequences {
requests: HashMap::new(),
prefill: PrefillLoadTracker::default(),
blocks: BlockTracker::default(),
#[cfg(test)]
block_size,
last_expiry_check_time: Instant::now(),
}
}
......@@ -157,14 +153,12 @@ impl ActiveSequences {
self.blocks.active_blocks()
}
#[cfg(test)]
/// Returns the active token count that the current prefill snapshot reports
/// when evaluated at `decay_now`.
pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
    let snapshot = self.prefill.snapshot();
    snapshot.active_tokens_at(decay_now)
}
/// Add a new request with optional prompt-token load accounting.
/// Returns block membership transitions plus any expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)]
pub(super) fn add_request_with_prefill_tracking(
&mut self,
request_id: RequestId,
......@@ -299,6 +293,31 @@ impl ActiveSequences {
membership_delta
}
/// Computes `isl - cached_tokens`, the part of the input sequence length that
/// is not covered by cached tokens.
///
/// If `cached_tokens` exceeds `isl`, the subtraction would underflow; in that
/// case an error is logged and 0 is returned.
pub fn new_tokens(&self, isl: usize, cached_tokens: usize) -> usize {
    match isl.checked_sub(cached_tokens) {
        Some(uncached) => uncached,
        None => {
            tracing::error!(
                "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens}, returning 0"
            );
            0
        }
    }
}
/// Shorthand for `potential_blocks_and_tokens_with_prefill_tracking` with
/// prefill-token tracking always switched on.
///
/// All other arguments are forwarded unchanged and the callee's pair is
/// returned as-is.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    cached_tokens: usize,
    decay_now: Instant,
) -> (usize, usize) {
    // This entry point never disables prefill tracking.
    let track_prefill_tokens = true;
    self.potential_blocks_and_tokens_with_prefill_tracking(
        token_sequence,
        isl,
        cached_tokens,
        track_prefill_tokens,
        decay_now,
    )
}
/// Add an output block with a random hash and optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
......@@ -330,12 +349,11 @@ impl ActiveSequences {
acquire.became_present_on_worker.then_some(random_hash)
}
#[cfg(test)]
fn potential_blocks_and_tokens_with_prefill_tracking(
pub fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
cached_tokens: usize,
track_prefill_tokens: bool,
decay_now: Instant,
) -> (usize, usize) {
......@@ -346,7 +364,7 @@ impl ActiveSequences {
};
let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens {
added_prefill_tokens(self.block_size, isl, overlap) + active_tokens
self.new_tokens(isl, cached_tokens) + active_tokens
} else {
active_tokens
};
......
This diff is collapsed.
......@@ -230,6 +230,7 @@ impl KvBlockManagerConfigBuilder {
engine_endpoint: String,
output_endpoint: Option<String>,
engine_source: crate::block_manager::kv_consolidator::EventSource,
mode: crate::block_manager::kv_consolidator::KvEventConsolidationMode,
) -> Self {
let config = match engine_source {
crate::block_manager::kv_consolidator::EventSource::Vllm => {
......@@ -237,6 +238,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_vllm(
engine_endpoint,
output_ep,
mode,
)
}
crate::block_manager::kv_consolidator::EventSource::Trtllm => {
......@@ -248,6 +250,7 @@ impl KvBlockManagerConfigBuilder {
crate::block_manager::kv_consolidator::KvEventConsolidatorConfig::new_trtllm(
engine_endpoint,
output_ep,
mode,
)
}
crate::block_manager::kv_consolidator::EventSource::Kvbm => {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment