Unverified Commit 6fab12be authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

test: add tests for lora aware kv cache events and routing (#6523)

parent 35128b14
...@@ -1911,7 +1911,10 @@ mod tests { ...@@ -1911,7 +1911,10 @@ mod tests {
use super::*; use super::*;
use crate::concurrent_radix_tree::ConcurrentRadixTree; use crate::concurrent_radix_tree::ConcurrentRadixTree;
use crate::nested_map::PositionalIndexer; use crate::nested_map::PositionalIndexer;
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash, compute_seq_hash_for_block}; use crate::protocols::{
ExternalSequenceBlockHash, LocalBlockHash, compute_block_hash_for_seq,
compute_seq_hash_for_block,
};
use rstest::rstest; use rstest::rstest;
use rstest_reuse::{self, *}; use rstest_reuse::{self, *};
use std::time::Instant; use std::time::Instant;
...@@ -2631,6 +2634,296 @@ mod tests { ...@@ -2631,6 +2634,296 @@ mod tests {
); );
} }
// ============================================================================
// LoRA isolation tests
// ============================================================================
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
// Same token sequence for both base model and LoRA adapter
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_hashes =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
// Hashes must differ despite identical tokens
assert_ne!(
base_hashes, lora_hashes,
"Base and LoRA hashes must differ for the same tokens"
);
let base_seq = compute_seq_hash_for_block(&base_hashes);
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0
let base_event = RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: base_hashes
.iter()
.zip(base_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
index.apply_event(base_event).await;
// Store LoRA blocks on worker 1
let lora_event = RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: lora_hashes
.iter()
.zip(lora_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
};
index.apply_event(lora_event).await;
// flush + settle time for thread-pool variants
index.flush().await;
tokio::time::sleep(Duration::from_millis(50)).await;
// Query with base-model hashes → only worker 0
let base_scores = index.find_matches(base_hashes.clone()).await.unwrap();
assert_eq!(
base_scores.scores.len(),
1,
"Only base-model worker should match"
);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
// Query with LoRA hashes → only worker 1
let lora_scores = index.find_matches(lora_hashes.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1, "Only LoRA worker should match");
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
}
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_local =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
assert_ne!(
base_local, lora_local,
"LoRA-aware hashing must produce different LocalBlockHash values"
);
// Simulate what vLLM does: same tokens, different rolling seq hashes
// because the engine accounts for the adapter internally.
let base_seq = compute_seq_hash_for_block(&base_local);
let lora_seq = compute_seq_hash_for_block(&lora_local);
// Worker 0: base model
index
.apply_event(RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: base_local
.iter()
.zip(base_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event(RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: lora_local
.iter()
.zip(lora_seq.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await;
index.flush().await;
tokio::time::sleep(Duration::from_millis(50)).await;
// Base query finds only worker 0
let base_scores = index.find_matches(base_local.clone()).await.unwrap();
assert_eq!(base_scores.scores.len(), 1);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
// LoRA query finds only worker 1
let lora_scores = index.find_matches(lora_local.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1);
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
let tokens: Vec<u32> = (0..kv_block_size * 2).collect();
let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a"));
let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b"));
assert_ne!(
hashes_a, hashes_b,
"Different adapters must produce different hashes"
);
let seq_a = compute_seq_hash_for_block(&hashes_a);
let seq_b = compute_seq_hash_for_block(&hashes_b);
// Store adapter-a blocks on worker 0
index
.apply_event(RouterEvent {
worker_id: 0,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes_a
.iter()
.zip(seq_a.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await;
// Store adapter-b blocks on worker 1
index
.apply_event(RouterEvent {
worker_id: 1,
event: KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes_b
.iter()
.zip(seq_b.iter())
.map(|(&local, &seq)| KvCacheStoredBlockData {
tokens_hash: local,
block_hash: ExternalSequenceBlockHash(seq),
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
})
.await;
index.flush().await;
tokio::time::sleep(Duration::from_millis(50)).await;
// Query adapter-a → only worker 0
let scores_a = index.find_matches(hashes_a.clone()).await.unwrap();
assert_eq!(scores_a.scores.len(), 1);
assert!(scores_a.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
assert!(!scores_a.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
// Query adapter-b → only worker 1
let scores_b = index.find_matches(hashes_b.clone()).await.unwrap();
assert_eq!(scores_b.scores.len(), 1);
assert!(scores_b.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
assert!(!scores_b.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
// ============================================================================ // ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer // Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================ // ============================================================================
......
...@@ -889,6 +889,62 @@ mod tests { ...@@ -889,6 +889,62 @@ mod tests {
assert_ne!(hash1, hash3); assert_ne!(hash1, hash3);
} }
#[test]
fn test_lora_name_round_trip_through_tracker() {
let mut tracker = CacheStatusTracker::new();
let should_publish = tracker.handle_store(
"hash_lora".to_string(),
EventSource::Vllm,
vec![1, 2, 3, 4],
None,
4,
Some("my-adapter".to_string()),
Some(StorageTier::Device),
None,
);
assert!(should_publish);
let events = tracker.drain_events();
assert_eq!(events.len(), 1);
match &events[0] {
ConsolidatedEvent::Store {
lora_name,
token_ids,
..
} => {
assert_eq!(lora_name.as_deref(), Some("my-adapter"));
assert_eq!(token_ids, &[1, 2, 3, 4]);
}
other => panic!("expected Store event, got: {:?}", other),
}
}
#[test]
fn test_lora_name_none_for_base_model() {
let mut tracker = CacheStatusTracker::new();
tracker.handle_store(
"hash_base".to_string(),
EventSource::Vllm,
vec![1, 2, 3, 4],
None,
4,
None,
Some(StorageTier::Device),
None,
);
let events = tracker.drain_events();
assert_eq!(events.len(), 1);
match &events[0] {
ConsolidatedEvent::Store { lora_name, .. } => {
assert!(lora_name.is_none());
}
other => panic!("expected Store event, got: {:?}", other),
}
}
#[test] #[test]
fn test_compute_sequence_hash_deterministic() { fn test_compute_sequence_hash_deterministic() {
let block_hash1 = compute_local_block_hash(&[1, 2, 3, 4]); let block_hash1 = compute_local_block_hash(&[1, 2, 3, 4]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment