Unverified Commit cdb40ec6 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix(kv-router): allow unit block size in slot tracking (#8395)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent e71f1d2b
...@@ -132,6 +132,8 @@ impl KvIndexer { ...@@ -132,6 +132,8 @@ impl KvIndexer {
metrics: Arc<KvIndexerMetrics>, metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>, prune_config: Option<PruneConfig>,
) -> Self { ) -> Self {
super::warn_on_unit_block_size("single", kv_block_size);
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(16384);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::<WorkerId>(16);
......
...@@ -31,6 +31,16 @@ ...@@ -31,6 +31,16 @@
//! //!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance. //! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
if kv_block_size == 1 {
tracing::warn!(
indexer_type,
kv_block_size,
"block_size=1 is supported for KV indexers, but consider avoiding it because KV events may saturate network bandwidth",
);
}
}
mod kv_indexer; mod kv_indexer;
mod local; mod local;
mod metrics; mod metrics;
......
...@@ -99,6 +99,7 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> { ...@@ -99,6 +99,7 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
metrics: Option<Arc<KvIndexerMetrics>>, metrics: Option<Arc<KvIndexerMetrics>>,
) -> Self { ) -> Self {
assert!(num_workers > 0, "Number of workers must be greater than 0"); assert!(num_workers > 0, "Number of workers must be greater than 0");
super::warn_on_unit_block_size("thread_pool", kv_block_size);
let backend = Arc::new(backend); let backend = Arc::new(backend);
let mut worker_event_senders = Vec::new(); let mut worker_event_senders = Vec::new();
......
...@@ -136,7 +136,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -136,7 +136,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
router_id: u64, router_id: u64,
worker_type: &'static str, worker_type: &'static str,
) -> Self { ) -> Self {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 0, "block_size must be greater than 0");
let (remote_state_updates, _) = watch::channel(()); let (remote_state_updates, _) = watch::channel(());
let workers = WorkerTable::new(block_size, &dp_range); let workers = WorkerTable::new(block_size, &dp_range);
let prompt_registry = PromptRegistry::new(workers.workers()); let prompt_registry = PromptRegistry::new(workers.workers());
...@@ -975,6 +975,19 @@ mod tests { ...@@ -975,6 +975,19 @@ mod tests {
) )
} }
fn make_multi_sequences_with_block_size(
block_size: usize,
) -> ActiveSequencesMultiWorker<NoopSequencePublisher> {
ActiveSequencesMultiWorker::new(
NoopSequencePublisher,
block_size,
HashMap::from([(1_u64, (0_u32, 1_u32)), (2_u64, (0_u32, 1_u32))]),
false,
0,
"test",
)
}
fn naive_potential_loads( fn naive_potential_loads(
sequences: &ActiveSequencesMultiWorker<NoopSequencePublisher>, sequences: &ActiveSequencesMultiWorker<NoopSequencePublisher>,
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
...@@ -1013,9 +1026,17 @@ mod tests { ...@@ -1013,9 +1026,17 @@ mod tests {
} }
fn seq_hashes_for_tokens(tokens: &[u32], lora_name: Option<&str>) -> Vec<SequenceHash> { fn seq_hashes_for_tokens(tokens: &[u32], lora_name: Option<&str>) -> Vec<SequenceHash> {
seq_hashes_for_tokens_with_block_size(tokens, 4, lora_name)
}
fn seq_hashes_for_tokens_with_block_size(
tokens: &[u32],
block_size: u32,
lora_name: Option<&str>,
) -> Vec<SequenceHash> {
let block_hashes = compute_block_hash_for_seq( let block_hashes = compute_block_hash_for_seq(
tokens, tokens,
4, block_size,
BlockHashOptions { BlockHashOptions {
lora_name, lora_name,
..Default::default() ..Default::default()
...@@ -1208,6 +1229,88 @@ mod tests { ...@@ -1208,6 +1229,88 @@ mod tests {
); );
} }
#[test]
fn unit_block_size_repeated_tokens_preserve_membership_and_trim() {
let sequences = make_multi_sequences_with_block_size(1);
let worker_a = WorkerWithDpRank::new(1, 0);
let worker_b = WorkerWithDpRank::new(2, 0);
let decay_now = Instant::now();
let prompt_a = seq_hashes_for_tokens_with_block_size(&[7_u32, 7, 7], 1, None);
let prompt_b = seq_hashes_for_tokens_with_block_size(&[7_u32, 7, 8], 1, None);
sequences
.add_request(
SequenceRequest {
request_id: "req-a".to_string(),
token_sequence: Some(prompt_a.clone()),
track_prefill_tokens: false,
expected_output_tokens: None,
prefill_load_hint: None,
worker: worker_a,
lora_name: None,
},
decay_now,
)
.unwrap();
sequences
.add_request(
SequenceRequest {
request_id: "req-b".to_string(),
token_sequence: Some(prompt_b.clone()),
track_prefill_tokens: false,
expected_output_tokens: None,
prefill_load_hint: None,
worker: worker_b,
lora_name: None,
},
decay_now,
)
.unwrap();
let expected = naive_potential_loads(
&sequences,
Some(&prompt_b),
3,
&OverlapScores::default(),
false,
decay_now,
);
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt_b),
3,
OverlapScores::default(),
false,
decay_now,
);
assert_eq!(actual, expected);
assert_eq!(actual.0.get(&worker_a).copied(), Some(4));
assert_eq!(actual.0.get(&worker_b).copied(), Some(3));
sequences.free(&"req-b".to_string(), decay_now).unwrap();
let expected_after_free = naive_potential_loads(
&sequences,
Some(&prompt_b),
3,
&OverlapScores::default(),
false,
decay_now,
);
let actual_after_free = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&prompt_b),
3,
OverlapScores::default(),
false,
decay_now,
);
assert_eq!(actual_after_free, expected_after_free);
assert_eq!(actual_after_free.0.get(&worker_a).copied(), Some(4));
assert_eq!(actual_after_free.0.get(&worker_b).copied(), Some(3));
sequences.free(&"req-a".to_string(), decay_now).unwrap();
sequences.assert_completely_drained(decay_now);
}
#[tokio::test(start_paused = true)] #[tokio::test(start_paused = true)]
async fn force_expiry_clears_block_membership_index() { async fn force_expiry_clears_block_membership_index() {
let sequences = make_multi_sequences(); let sequences = make_multi_sequences();
......
...@@ -117,7 +117,7 @@ pub struct ActiveSequences { ...@@ -117,7 +117,7 @@ pub struct ActiveSequences {
impl ActiveSequences { impl ActiveSequences {
/// Create a new SharedSequenceManager instance /// Create a new SharedSequenceManager instance
pub(super) fn new(block_size: usize) -> Self { pub(super) fn new(block_size: usize) -> Self {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 0, "block_size must be greater than 0");
Self { Self {
requests: HashMap::new(), requests: HashMap::new(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment