Unverified commit 75fea787, authored by Biswa Panda, committed by GitHub
Browse files

feat: make KV cache events and routing LoRA-aware (#6517)

parent 432fae67
......@@ -141,6 +141,7 @@ impl KvRouterConfig {
tokens: &[u32],
block_size: u32,
config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>,
) -> Option<Vec<u64>> {
if !self.router_track_active_blocks {
return None;
......@@ -151,17 +152,14 @@ impl KvRouterConfig {
return Some(Vec::new());
}
// Use override if provided, otherwise use default config
let assume_kv_reuse = config_override
.and_then(|cfg| cfg.assume_kv_reuse)
.unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse {
// Compute actual block hashes and sequence hashes
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None, lora_name);
Some(compute_seq_hash_for_block(&block_hashes))
} else {
// Generate random hashes (no KV reuse assumed)
let mut rng = rand::rng();
Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
}
......
......@@ -31,6 +31,7 @@ pub enum IndexerQueryRequest {
FindMatchesTokens {
tokens: Vec<u32>,
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
lora_name: Option<String>,
},
DumpTree,
}
......@@ -90,7 +91,13 @@ impl
IndexerQueryRequest::FindMatchesTokens {
tokens,
block_mm_infos,
} => compute_block_hash_for_seq(&tokens, self.block_size, block_mm_infos.as_deref()),
lora_name,
} => compute_block_hash_for_seq(
&tokens,
self.block_size,
block_mm_infos.as_deref(),
lora_name.as_deref(),
),
IndexerQueryRequest::DumpTree => unreachable!(),
};
......
......@@ -574,9 +574,9 @@ fn convert_event(
parent_block_hash,
token_ids,
block_size,
lora_id,
lora_name,
block_mm_infos,
..
medium: _,
} => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique.
{
......@@ -614,7 +614,7 @@ fn convert_event(
&token_ids,
&num_block_tokens,
&block_hashes_u64,
lora_id.unwrap_or(0),
lora_name.as_deref(),
warning_count,
block_mm_infos.as_deref(),
),
......@@ -648,13 +648,16 @@ pub fn create_stored_block_from_parts(
kv_block_size: u32,
block_hash: u64,
token_ids: &[u32],
_lora_id: u64,
lora_name: Option<&str>,
mm_extra_info: Option<BlockExtraInfo>,
) -> KvCacheStoredBlockData {
// Compute tokens_hash including MM info if present
let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
let tokens_hash =
compute_block_hash_for_seq(token_ids, kv_block_size, block_mm_infos.as_deref())[0];
let tokens_hash = compute_block_hash_for_seq(
token_ids,
kv_block_size,
block_mm_infos.as_deref(),
lora_name,
)[0];
tracing::trace!(
"Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}, mm_extra_info={:?}",
......@@ -676,7 +679,7 @@ pub fn create_stored_blocks(
token_ids: &[u32],
num_block_tokens: &[u64],
block_hashes: &[u64],
lora_id: u64,
lora_name: Option<&str>,
warning_count: &Arc<AtomicU32>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<KvCacheStoredBlockData> {
......@@ -706,7 +709,7 @@ pub fn create_stored_blocks(
kv_block_size,
*block_hash_it,
tokens,
lora_id,
lora_name,
mm_extra_info,
));
token_offset += *num_tokens_it as usize;
......@@ -768,11 +771,9 @@ enum RawKvEvent {
parent_block_hash: Option<BlockHashValue>,
token_ids: Vec<u32>,
block_size: usize,
/// Deprecated in vLLM 0.14.0: use `lora_name` instead
lora_id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
medium: Option<String>,
/// LoRA adapter name (added in vLLM 0.14.0, replaces lora_id)
/// LoRA adapter name for adapter-aware block hashing
#[serde(default, skip_serializing_if = "Option::is_none")]
lora_name: Option<String>,
/// Multimodal extra info for each block (length should match block_hashes)
......@@ -873,7 +874,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut parent_block_hash: Option<Option<BlockHashValue>> = None;
let mut token_ids: Option<Vec<u32>> = None;
let mut block_size: Option<usize> = None;
let mut lora_id: Option<Option<u64>> = None;
let mut medium: Option<Option<String>> = None;
let mut lora_name: Option<Option<String>> = None;
let mut extra_keys: Option<Option<Vec<Option<Vec<String>>>>> = None;
......@@ -896,9 +896,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"block_size" => {
block_size = Some(map.next_value()?);
}
"lora_id" => {
lora_id = Some(map.next_value()?);
}
"medium" => {
medium = Some(map.next_value()?);
}
......@@ -932,7 +929,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
parent_block_hash: parent_block_hash.unwrap_or(None),
token_ids,
block_size,
lora_id: lora_id.unwrap_or(None),
medium: medium.unwrap_or(None),
lora_name: lora_name.unwrap_or(None),
block_mm_infos,
......@@ -979,7 +975,8 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let block_size: usize = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?;
let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
// Position 5 was lora_id in older formats; consume and discard for compat
let _lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
let extra_keys: Option<Vec<Option<Vec<String>>>> =
......@@ -997,7 +994,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
parent_block_hash,
token_ids,
block_size,
lora_id,
medium,
lora_name,
block_mm_infos,
......@@ -1172,10 +1168,11 @@ mod test_event_processing {
let token_ids = vec![10, 20, 30, 40];
let blk_hash = 0xdead_beef;
let stored = create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, 0, None);
let stored =
create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, None, None);
assert_eq!(stored.block_hash.0, blk_hash);
let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None)[0];
let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None, None)[0];
assert_eq!(stored.tokens_hash, expected_hash);
assert!(stored.mm_extra_info.is_none());
}
......@@ -1196,7 +1193,7 @@ mod test_event_processing {
&token_ids,
&num_block_tokens,
&block_hashes,
/*lora_id=*/ 0,
None,
&Arc::new(AtomicU32::new(0)),
None,
);
......@@ -1209,7 +1206,6 @@ mod test_event_processing {
#[test]
fn test_create_stored_blocks_wrong_size_triggers_warning() {
let kv_block_size = 4;
// second block is the wrong size
let token_ids = vec![1, 2, 3, 4, 5, 6, 7];
let num_block_tokens = vec![4_u64, 3_u64];
let block_hashes = vec![111_u64, 222_u64];
......@@ -1220,7 +1216,7 @@ mod test_event_processing {
&token_ids,
&num_block_tokens,
&block_hashes,
/*lora_id=*/ 0,
None,
&warning_count,
None,
);
......@@ -1241,7 +1237,6 @@ mod test_event_processing {
parent_block_hash: Some(BlockHashValue::Unsigned(99)),
token_ids: vec![1, 2, 3, 4, 5, 6, 7, 8],
block_size: 4,
lora_id: Some(0),
medium: None,
lora_name: None,
block_mm_infos: None,
......@@ -1251,6 +1246,146 @@ mod test_event_processing {
assert!(matches!(out.data, KvCacheEventData::Stored(_)));
}
/// A `BlockStored` event carrying a LoRA adapter name must yield a different
/// `tokens_hash` than an otherwise identical event without one.
#[test]
fn test_convert_event_with_lora_name() {
    let kv_block_size = 4;
    let token_ids = vec![1, 2, 3, 4];

    // Both events share every field except the adapter name.
    let stored_event = |adapter: Option<&str>| RawKvEvent::BlockStored {
        block_hashes: vec![BlockHashValue::Unsigned(10)],
        parent_block_hash: None,
        token_ids: token_ids.clone(),
        block_size: 4,
        medium: None,
        lora_name: adapter.map(str::to_string),
        block_mm_infos: None,
    };
    let plain_event = stored_event(None);
    let adapted_event = stored_event(Some("my-lora"));

    let warnings = Arc::new(AtomicU32::new(0));
    let plain_out = convert_event(plain_event, 1, kv_block_size, 0, &warnings);
    let adapted_out = convert_event(adapted_event, 2, kv_block_size, 0, &warnings);

    let KvCacheEventData::Stored(plain_stored) = &plain_out.data else {
        panic!("expected Stored");
    };
    let KvCacheEventData::Stored(adapted_stored) = &adapted_out.data else {
        panic!("expected Stored");
    };

    assert_ne!(
        plain_stored.blocks[0].tokens_hash,
        adapted_stored.blocks[0].tokens_hash,
        "LoRA blocks must produce distinct tokens_hash"
    );
}
/// Two `BlockStored` events with `lora_name: None` and the same tokens must
/// produce the same `tokens_hash`, regardless of their event ids.
#[test]
fn test_convert_event_lora_name_none_is_base_model() {
    let kv_block_size = 4;
    let token_ids = vec![1, 2, 3, 4];
    let warnings = Arc::new(AtomicU32::new(0));

    // Builds a fresh, identical base-model event on every call.
    let base_event = || RawKvEvent::BlockStored {
        block_hashes: vec![BlockHashValue::Unsigned(10)],
        parent_block_hash: None,
        token_ids: token_ids.clone(),
        block_size: 4,
        medium: None,
        lora_name: None,
        block_mm_infos: None,
    };
    let first_event = base_event();
    let second_event = base_event();

    let first_out = convert_event(first_event, 1, kv_block_size, 0, &warnings);
    let second_out = convert_event(second_event, 2, kv_block_size, 0, &warnings);

    let KvCacheEventData::Stored(first_stored) = &first_out.data else {
        panic!("expected Stored");
    };
    let KvCacheEventData::Stored(second_stored) = &second_out.data else {
        panic!("expected Stored");
    };

    assert_eq!(
        first_stored.blocks[0].tokens_hash,
        second_stored.blocks[0].tokens_hash,
        "Two base-model events with same tokens should produce same hash"
    );
}
/// A map-encoded `BlockStored` payload from an older producer (carrying
/// `lora_id` but no `lora_name`) must still deserialize, with `lora_name`
/// left as `None`.
#[test]
fn test_backward_compat_deserialize_map_with_lora_id_no_lora_name() {
    #[derive(serde::Serialize)]
    struct OldFormatEvent {
        #[serde(rename = "type")]
        event_type: &'static str,
        block_hashes: Vec<u64>,
        parent_block_hash: Option<u64>,
        token_ids: Vec<u32>,
        block_size: usize,
        lora_id: Option<u64>,
    }

    // NOTE(review): assumes `rmps` aliases `rmp_serde`. Its plain `to_vec`
    // encodes structs positionally (as an array), which would exercise the
    // sequence visitor instead of the map visitor this test is named for.
    // `to_vec_named` keeps the field names, so the payload is a real map.
    let payload = rmps::to_vec_named(&OldFormatEvent {
        event_type: "BlockStored",
        block_hashes: vec![42],
        parent_block_hash: None,
        token_ids: vec![1, 2, 3, 4],
        block_size: 4,
        lora_id: Some(5),
    })
    .unwrap();

    // MessagePack "fixmap" markers are 0x80..=0x8f; guard against the payload
    // silently falling back to an array encoding.
    assert_eq!(
        payload[0] & 0xf0,
        0x80,
        "payload must be map-encoded to exercise the map deserialization path"
    );

    let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
    let RawKvEvent::BlockStored { lora_name, .. } = event else {
        panic!("expected BlockStored");
    };
    assert!(
        lora_name.is_none(),
        "old-format payloads with lora_id but no lora_name should deserialize with lora_name=None"
    );
}
/// A positional (sequence-encoded) `BlockStored` payload from an older
/// producer, with `lora_id` in slot 5 and nothing after it, must deserialize
/// with `lora_name` left as `None`.
#[test]
fn test_backward_compat_deserialize_seq_with_lora_id_no_lora_name() {
    let old_positional_fields = (
        "BlockStored",
        vec![42_u64],
        None::<u64>,
        vec![1_u32, 2, 3, 4],
        4_usize,
        Some(5_u64), // lora_id at position 5
                     // no medium, no lora_name — simulating an old producer
    );
    let bytes = rmps::to_vec(&old_positional_fields).unwrap();

    let decoded: RawKvEvent = rmps::from_slice(&bytes).unwrap();
    match decoded {
        RawKvEvent::BlockStored { lora_name, .. } => assert!(
            lora_name.is_none(),
            "old seq-format payloads with lora_id should deserialize with lora_name=None"
        ),
        _ => panic!("expected BlockStored"),
    }
}
#[test]
fn test_convert_event_block_removed() {
let kv_block_size = 4;
......@@ -1795,7 +1930,6 @@ mod tests_startup_helpers {
parent_block_hash: None,
token_ids: vec![0, 1, 2, 3],
block_size: 4,
lora_id: None,
medium: None,
lora_name: None,
block_mm_infos: None,
......
......@@ -262,7 +262,7 @@ impl KvPushRouter {
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self
.chooser
.get_overlap_blocks(routing_token_ids, worker)
.get_overlap_blocks(routing_token_ids, worker, lora_name.as_deref())
.await?;
if !is_query_only {
......
......@@ -121,7 +121,11 @@ def set_ucx_tls_no_mm():
# (uct_mem.c:482: mem.memh != UCT_MEM_HANDLE_NULL) when two workers
# start on the same node (maybe a shared-memory segment collision/limits).
# - Mitigation: disable UCX "mm" shared-memory transport globally for tests
mp.setenv("UCX_TLS", "^mm")
#
# Also exclude gdr_copy transport to prevent GDRCopy driver initialization
# failures (driverInitFileInfo result=11) that can abort the process when
# the gdrdrv kernel module is not loaded.
mp.setenv("UCX_TLS", "^mm,gdr_copy")
yield
mp.undo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment