Unverified commit 75fea787, authored by Biswa Panda and committed by GitHub
Browse files

feat: make KV cache events and routing LoRA-aware (#6517)

parent 432fae67
...@@ -141,6 +141,7 @@ impl KvRouterConfig { ...@@ -141,6 +141,7 @@ impl KvRouterConfig {
tokens: &[u32], tokens: &[u32],
block_size: u32, block_size: u32,
config_override: Option<&RouterConfigOverride>, config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>,
) -> Option<Vec<u64>> { ) -> Option<Vec<u64>> {
if !self.router_track_active_blocks { if !self.router_track_active_blocks {
return None; return None;
...@@ -151,17 +152,14 @@ impl KvRouterConfig { ...@@ -151,17 +152,14 @@ impl KvRouterConfig {
return Some(Vec::new()); return Some(Vec::new());
} }
// Use override if provided, otherwise use default config
let assume_kv_reuse = config_override let assume_kv_reuse = config_override
.and_then(|cfg| cfg.assume_kv_reuse) .and_then(|cfg| cfg.assume_kv_reuse)
.unwrap_or(self.router_assume_kv_reuse); .unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse { if assume_kv_reuse {
// Compute actual block hashes and sequence hashes let block_hashes = compute_block_hash_for_seq(tokens, block_size, None, lora_name);
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
Some(compute_seq_hash_for_block(&block_hashes)) Some(compute_seq_hash_for_block(&block_hashes))
} else { } else {
// Generate random hashes (no KV reuse assumed)
let mut rng = rand::rng(); let mut rng = rand::rng();
Some((0..num_blocks).map(|_| rng.random::<u64>()).collect()) Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
} }
......
...@@ -31,6 +31,7 @@ pub enum IndexerQueryRequest { ...@@ -31,6 +31,7 @@ pub enum IndexerQueryRequest {
FindMatchesTokens { FindMatchesTokens {
tokens: Vec<u32>, tokens: Vec<u32>,
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>, block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
lora_name: Option<String>,
}, },
DumpTree, DumpTree,
} }
...@@ -90,7 +91,13 @@ impl ...@@ -90,7 +91,13 @@ impl
IndexerQueryRequest::FindMatchesTokens { IndexerQueryRequest::FindMatchesTokens {
tokens, tokens,
block_mm_infos, block_mm_infos,
} => compute_block_hash_for_seq(&tokens, self.block_size, block_mm_infos.as_deref()), lora_name,
} => compute_block_hash_for_seq(
&tokens,
self.block_size,
block_mm_infos.as_deref(),
lora_name.as_deref(),
),
IndexerQueryRequest::DumpTree => unreachable!(), IndexerQueryRequest::DumpTree => unreachable!(),
}; };
......
...@@ -574,9 +574,9 @@ fn convert_event( ...@@ -574,9 +574,9 @@ fn convert_event(
parent_block_hash, parent_block_hash,
token_ids, token_ids,
block_size, block_size,
lora_id, lora_name,
block_mm_infos, block_mm_infos,
.. medium: _,
} => { } => {
// Reject self-referencing blocks: all block hashes (including parent) must be unique. // Reject self-referencing blocks: all block hashes (including parent) must be unique.
{ {
...@@ -614,7 +614,7 @@ fn convert_event( ...@@ -614,7 +614,7 @@ fn convert_event(
&token_ids, &token_ids,
&num_block_tokens, &num_block_tokens,
&block_hashes_u64, &block_hashes_u64,
lora_id.unwrap_or(0), lora_name.as_deref(),
warning_count, warning_count,
block_mm_infos.as_deref(), block_mm_infos.as_deref(),
), ),
...@@ -648,13 +648,16 @@ pub fn create_stored_block_from_parts( ...@@ -648,13 +648,16 @@ pub fn create_stored_block_from_parts(
kv_block_size: u32, kv_block_size: u32,
block_hash: u64, block_hash: u64,
token_ids: &[u32], token_ids: &[u32],
_lora_id: u64, lora_name: Option<&str>,
mm_extra_info: Option<BlockExtraInfo>, mm_extra_info: Option<BlockExtraInfo>,
) -> KvCacheStoredBlockData { ) -> KvCacheStoredBlockData {
// Compute tokens_hash including MM info if present
let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]); let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
let tokens_hash = let tokens_hash = compute_block_hash_for_seq(
compute_block_hash_for_seq(token_ids, kv_block_size, block_mm_infos.as_deref())[0]; token_ids,
kv_block_size,
block_mm_infos.as_deref(),
lora_name,
)[0];
tracing::trace!( tracing::trace!(
"Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}, mm_extra_info={:?}", "Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}, mm_extra_info={:?}",
...@@ -676,7 +679,7 @@ pub fn create_stored_blocks( ...@@ -676,7 +679,7 @@ pub fn create_stored_blocks(
token_ids: &[u32], token_ids: &[u32],
num_block_tokens: &[u64], num_block_tokens: &[u64],
block_hashes: &[u64], block_hashes: &[u64],
lora_id: u64, lora_name: Option<&str>,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>, block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<KvCacheStoredBlockData> { ) -> Vec<KvCacheStoredBlockData> {
...@@ -706,7 +709,7 @@ pub fn create_stored_blocks( ...@@ -706,7 +709,7 @@ pub fn create_stored_blocks(
kv_block_size, kv_block_size,
*block_hash_it, *block_hash_it,
tokens, tokens,
lora_id, lora_name,
mm_extra_info, mm_extra_info,
)); ));
token_offset += *num_tokens_it as usize; token_offset += *num_tokens_it as usize;
...@@ -768,11 +771,9 @@ enum RawKvEvent { ...@@ -768,11 +771,9 @@ enum RawKvEvent {
parent_block_hash: Option<BlockHashValue>, parent_block_hash: Option<BlockHashValue>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
block_size: usize, block_size: usize,
/// Deprecated in vLLM 0.14.0: use `lora_name` instead
lora_id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
medium: Option<String>, medium: Option<String>,
/// LoRA adapter name (added in vLLM 0.14.0, replaces lora_id) /// LoRA adapter name for adapter-aware block hashing
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
lora_name: Option<String>, lora_name: Option<String>,
/// Multimodal extra info for each block (length should match block_hashes) /// Multimodal extra info for each block (length should match block_hashes)
...@@ -873,7 +874,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -873,7 +874,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut parent_block_hash: Option<Option<BlockHashValue>> = None; let mut parent_block_hash: Option<Option<BlockHashValue>> = None;
let mut token_ids: Option<Vec<u32>> = None; let mut token_ids: Option<Vec<u32>> = None;
let mut block_size: Option<usize> = None; let mut block_size: Option<usize> = None;
let mut lora_id: Option<Option<u64>> = None;
let mut medium: Option<Option<String>> = None; let mut medium: Option<Option<String>> = None;
let mut lora_name: Option<Option<String>> = None; let mut lora_name: Option<Option<String>> = None;
let mut extra_keys: Option<Option<Vec<Option<Vec<String>>>>> = None; let mut extra_keys: Option<Option<Vec<Option<Vec<String>>>>> = None;
...@@ -896,9 +896,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -896,9 +896,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"block_size" => { "block_size" => {
block_size = Some(map.next_value()?); block_size = Some(map.next_value()?);
} }
"lora_id" => {
lora_id = Some(map.next_value()?);
}
"medium" => { "medium" => {
medium = Some(map.next_value()?); medium = Some(map.next_value()?);
} }
...@@ -932,7 +929,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -932,7 +929,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
parent_block_hash: parent_block_hash.unwrap_or(None), parent_block_hash: parent_block_hash.unwrap_or(None),
token_ids, token_ids,
block_size, block_size,
lora_id: lora_id.unwrap_or(None),
medium: medium.unwrap_or(None), medium: medium.unwrap_or(None),
lora_name: lora_name.unwrap_or(None), lora_name: lora_name.unwrap_or(None),
block_mm_infos, block_mm_infos,
...@@ -979,7 +975,8 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -979,7 +975,8 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let block_size: usize = seq let block_size: usize = seq
.next_element()? .next_element()?
.ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?; .ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?;
let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None); // Position 5 was lora_id in older formats; consume and discard for compat
let _lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
let medium: Option<String> = seq.next_element()?.unwrap_or(None); let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let lora_name: Option<String> = seq.next_element()?.unwrap_or(None); let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
let extra_keys: Option<Vec<Option<Vec<String>>>> = let extra_keys: Option<Vec<Option<Vec<String>>>> =
...@@ -997,7 +994,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -997,7 +994,6 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
parent_block_hash, parent_block_hash,
token_ids, token_ids,
block_size, block_size,
lora_id,
medium, medium,
lora_name, lora_name,
block_mm_infos, block_mm_infos,
...@@ -1172,10 +1168,11 @@ mod test_event_processing { ...@@ -1172,10 +1168,11 @@ mod test_event_processing {
let token_ids = vec![10, 20, 30, 40]; let token_ids = vec![10, 20, 30, 40];
let blk_hash = 0xdead_beef; let blk_hash = 0xdead_beef;
let stored = create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, 0, None); let stored =
create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, None, None);
assert_eq!(stored.block_hash.0, blk_hash); assert_eq!(stored.block_hash.0, blk_hash);
let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None)[0]; let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None, None)[0];
assert_eq!(stored.tokens_hash, expected_hash); assert_eq!(stored.tokens_hash, expected_hash);
assert!(stored.mm_extra_info.is_none()); assert!(stored.mm_extra_info.is_none());
} }
...@@ -1196,7 +1193,7 @@ mod test_event_processing { ...@@ -1196,7 +1193,7 @@ mod test_event_processing {
&token_ids, &token_ids,
&num_block_tokens, &num_block_tokens,
&block_hashes, &block_hashes,
/*lora_id=*/ 0, None,
&Arc::new(AtomicU32::new(0)), &Arc::new(AtomicU32::new(0)),
None, None,
); );
...@@ -1209,7 +1206,6 @@ mod test_event_processing { ...@@ -1209,7 +1206,6 @@ mod test_event_processing {
#[test] #[test]
fn test_create_stored_blocks_wrong_size_triggers_warning() { fn test_create_stored_blocks_wrong_size_triggers_warning() {
let kv_block_size = 4; let kv_block_size = 4;
// second block is the wrong size
let token_ids = vec![1, 2, 3, 4, 5, 6, 7]; let token_ids = vec![1, 2, 3, 4, 5, 6, 7];
let num_block_tokens = vec![4_u64, 3_u64]; let num_block_tokens = vec![4_u64, 3_u64];
let block_hashes = vec![111_u64, 222_u64]; let block_hashes = vec![111_u64, 222_u64];
...@@ -1220,7 +1216,7 @@ mod test_event_processing { ...@@ -1220,7 +1216,7 @@ mod test_event_processing {
&token_ids, &token_ids,
&num_block_tokens, &num_block_tokens,
&block_hashes, &block_hashes,
/*lora_id=*/ 0, None,
&warning_count, &warning_count,
None, None,
); );
...@@ -1241,7 +1237,6 @@ mod test_event_processing { ...@@ -1241,7 +1237,6 @@ mod test_event_processing {
parent_block_hash: Some(BlockHashValue::Unsigned(99)), parent_block_hash: Some(BlockHashValue::Unsigned(99)),
token_ids: vec![1, 2, 3, 4, 5, 6, 7, 8], token_ids: vec![1, 2, 3, 4, 5, 6, 7, 8],
block_size: 4, block_size: 4,
lora_id: Some(0),
medium: None, medium: None,
lora_name: None, lora_name: None,
block_mm_infos: None, block_mm_infos: None,
...@@ -1251,6 +1246,146 @@ mod test_event_processing { ...@@ -1251,6 +1246,146 @@ mod test_event_processing {
assert!(matches!(out.data, KvCacheEventData::Stored(_))); assert!(matches!(out.data, KvCacheEventData::Stored(_)));
} }
/// A stored block tagged with a LoRA adapter name must not share its
/// `tokens_hash` with the same tokens stored without an adapter.
#[test]
fn test_convert_event_with_lora_name() {
    let kv_block_size = 4;
    let warnings = Arc::new(AtomicU32::new(0));

    // Single-block store event; only the adapter name varies between calls.
    let stored_event = |adapter: Option<&str>| RawKvEvent::BlockStored {
        block_hashes: vec![BlockHashValue::Unsigned(10)],
        parent_block_hash: None,
        token_ids: vec![1, 2, 3, 4],
        block_size: 4,
        medium: None,
        lora_name: adapter.map(str::to_string),
        block_mm_infos: None,
    };

    // Convert the raw event and pull out the tokens hash of its first block.
    let first_tokens_hash = |event: RawKvEvent, event_id| {
        let converted = convert_event(event, event_id, kv_block_size, 0, &warnings);
        match &converted.data {
            KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
            _ => panic!("expected Stored"),
        }
    };

    let base_hash = first_tokens_hash(stored_event(None), 1);
    let lora_hash = first_tokens_hash(stored_event(Some("my-lora")), 2);

    assert_ne!(
        base_hash, lora_hash,
        "LoRA blocks must produce distinct tokens_hash"
    );
}
/// Two events without an adapter name and with identical tokens must yield
/// the same `tokens_hash`, regardless of their event ids.
#[test]
fn test_convert_event_lora_name_none_is_base_model() {
    let kv_block_size = 4;
    let warnings = Arc::new(AtomicU32::new(0));

    // Convert the same base-model event under event ids 1 and 2, keeping the
    // tokens hash of the first stored block from each conversion.
    let hashes: Vec<_> = (1..=2)
        .map(|event_id| {
            let event = RawKvEvent::BlockStored {
                block_hashes: vec![BlockHashValue::Unsigned(10)],
                parent_block_hash: None,
                token_ids: vec![1, 2, 3, 4],
                block_size: 4,
                medium: None,
                lora_name: None,
                block_mm_infos: None,
            };
            let converted = convert_event(event, event_id, kv_block_size, 0, &warnings);
            match &converted.data {
                KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
                _ => panic!("expected Stored"),
            }
        })
        .collect();

    assert_eq!(
        hashes[0], hashes[1],
        "Two base-model events with same tokens should produce same hash"
    );
}
/// Old producers sent map-encoded `BlockStored` events carrying a `lora_id`
/// key and no `lora_name`. Such payloads must still deserialize, with
/// `lora_name` left as `None`.
#[test]
fn test_backward_compat_deserialize_map_with_lora_id_no_lora_name() {
    /// Field layout of a legacy map-format event.
    #[derive(serde::Serialize)]
    struct OldFormatEvent {
        #[serde(rename = "type")]
        event_type: &'static str,
        block_hashes: Vec<u64>,
        parent_block_hash: Option<u64>,
        token_ids: Vec<u32>,
        block_size: usize,
        lora_id: Option<u64>,
    }

    // `to_vec` would encode the struct as a positional array, which exercises
    // the sequence path instead. `to_vec_named` keeps the field names so the
    // payload really is a map containing the legacy `lora_id` key.
    let payload = rmps::to_vec_named(&OldFormatEvent {
        event_type: "BlockStored",
        block_hashes: vec![42],
        parent_block_hash: None,
        token_ids: vec![1, 2, 3, 4],
        block_size: 4,
        lora_id: Some(5),
    })
    .unwrap();

    let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
    let RawKvEvent::BlockStored { lora_name, .. } = event else {
        panic!("expected BlockStored");
    };
    assert!(
        lora_name.is_none(),
        "old-format payloads with lora_id but no lora_name should deserialize with lora_name=None"
    );
}
/// Old producers sent positional (array-encoded) `BlockStored` events with
/// `lora_id` in slot 5 and nothing after it. Such payloads must still
/// deserialize, with `lora_name` left as `None`.
#[test]
fn test_backward_compat_deserialize_seq_with_lora_id_no_lora_name() {
    // Legacy positional layout: tag, block hashes, parent hash, token ids,
    // block size, lora_id. The trailing medium and lora_name are absent.
    let legacy_fields = (
        "BlockStored",
        vec![42_u64],
        None::<u64>,
        vec![1_u32, 2, 3, 4],
        4_usize,
        Some(5_u64),
    );
    let encoded = rmps::to_vec(&legacy_fields).unwrap();

    let decoded: RawKvEvent = rmps::from_slice(&encoded).unwrap();
    match decoded {
        RawKvEvent::BlockStored { lora_name, .. } => assert!(
            lora_name.is_none(),
            "old seq-format payloads with lora_id should deserialize with lora_name=None"
        ),
        _ => panic!("expected BlockStored"),
    }
}
#[test] #[test]
fn test_convert_event_block_removed() { fn test_convert_event_block_removed() {
let kv_block_size = 4; let kv_block_size = 4;
...@@ -1795,7 +1930,6 @@ mod tests_startup_helpers { ...@@ -1795,7 +1930,6 @@ mod tests_startup_helpers {
parent_block_hash: None, parent_block_hash: None,
token_ids: vec![0, 1, 2, 3], token_ids: vec![0, 1, 2, 3],
block_size: 4, block_size: 4,
lora_id: None,
medium: None, medium: None,
lora_name: None, lora_name: None,
block_mm_infos: None, block_mm_infos: None,
......
...@@ -262,7 +262,7 @@ impl KvPushRouter { ...@@ -262,7 +262,7 @@ impl KvPushRouter {
let worker = WorkerWithDpRank::new(id, dp_rank); let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self let overlap_blocks = self
.chooser .chooser
.get_overlap_blocks(routing_token_ids, worker) .get_overlap_blocks(routing_token_ids, worker, lora_name.as_deref())
.await?; .await?;
if !is_query_only { if !is_query_only {
......
...@@ -121,7 +121,11 @@ def set_ucx_tls_no_mm(): ...@@ -121,7 +121,11 @@ def set_ucx_tls_no_mm():
# (uct_mem.c:482: mem.memh != UCT_MEM_HANDLE_NULL) when two workers # (uct_mem.c:482: mem.memh != UCT_MEM_HANDLE_NULL) when two workers
# start on the same node (maybe a shared-memory segment collision/limits). # start on the same node (maybe a shared-memory segment collision/limits).
# - Mitigation: disable UCX "mm" shared-memory transport globally for tests # - Mitigation: disable UCX "mm" shared-memory transport globally for tests
mp.setenv("UCX_TLS", "^mm") #
# Also exclude gdr_copy transport to prevent GDRCopy driver initialization
# failures (driverInitFileInfo result=11) that can abort the process when
# the gdrdrv kernel module is not loaded.
mp.setenv("UCX_TLS", "^mm,gdr_copy")
yield yield
mp.undo() mp.undo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment