"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "38364a7e32c923cfbd71f4d749fb2791c1741bd4"
Unverified commit 01002df7, authored by Yan Ru Pei and committed by GitHub
Browse files

fix(kv-router): filter non-main ZMQ KV event groups (cherry-pick) (#8705)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 7d4572f9
......@@ -46,6 +46,10 @@ pytestmark = [
]
def _has_group_idx(event_cls):
return "group_idx" in event_cls.__struct_fields__
class TestVllmKvEventsApi:
"""Test vLLM KV events API compatibility."""
......@@ -61,10 +65,11 @@ class TestVllmKvEventsApi:
6. medium
7. lora_name (added in vLLM 0.14.0)
8. extra_keys (added in vLLM 0.17.0)
9. group_idx (added for hybrid KV cache groups; optional for older vLLM)
If vLLM adds/removes/reorders fields, this test will fail.
"""
expected_fields = (
expected_fields = [
"block_hashes",
"parent_block_hash",
"token_ids",
......@@ -73,7 +78,10 @@ class TestVllmKvEventsApi:
"medium",
"lora_name",
"extra_keys",
)
]
if _has_group_idx(BlockStored):
expected_fields.append("group_idx")
expected_fields = tuple(expected_fields)
actual_fields = BlockStored.__struct_fields__
assert actual_fields == expected_fields, (
......@@ -88,10 +96,13 @@ class TestVllmKvEventsApi:
def test_block_removed_fields(self):
"""Verify BlockRemoved has expected fields in expected order."""
expected_fields = (
expected_fields = [
"block_hashes",
"medium",
)
]
if _has_group_idx(BlockRemoved):
expected_fields.append("group_idx")
expected_fields = tuple(expected_fields)
actual_fields = BlockRemoved.__struct_fields__
assert actual_fields == expected_fields, (
......@@ -158,16 +169,19 @@ class TestVllmKvEventsApi:
"""
import msgspec
event = BlockStored(
block_hashes=[123, 456],
parent_block_hash=789,
token_ids=[1, 2, 3, 4],
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
extra_keys=None,
)
event_kwargs = {
"block_hashes": [123, 456],
"parent_block_hash": 789,
"token_ids": [1, 2, 3, 4],
"block_size": 16,
"lora_id": None,
"medium": "GPU",
"lora_name": None,
"extra_keys": None,
}
if _has_group_idx(BlockStored):
event_kwargs["group_idx"] = 0
event = BlockStored(**event_kwargs)
encoded = msgspec.msgpack.encode(event)
decoded = msgspec.msgpack.decode(encoded)
......@@ -178,9 +192,9 @@ class TestVllmKvEventsApi:
decoded[0] == "BlockStored"
), f"Expected tag 'BlockStored', got {decoded[0]}"
# Verify field count (tag + 8 fields = 9 elements; 10 when group_idx is present)
assert len(decoded) == 9, (
f"Expected 9 elements (tag + 8 fields), got {len(decoded)}.\n"
expected_len = 10 if _has_group_idx(BlockStored) else 9
assert len(decoded) == expected_len, (
f"Expected {expected_len} elements, got {len(decoded)}.\n"
f"Decoded: {decoded}\n"
f"If field count changed, update Rust deserializers."
)
......@@ -194,22 +208,27 @@ class TestVllmKvEventsApi:
assert decoded[6] == "GPU", f"medium at wrong position: {decoded[6]}"
assert decoded[7] is None, f"lora_name at wrong position: {decoded[7]}"
assert decoded[8] is None, f"extra_keys at wrong position: {decoded[8]}"
if _has_group_idx(BlockStored):
assert decoded[9] == 0, f"group_idx at wrong position: {decoded[9]}"
def test_block_stored_tuple_extra_keys_serialization_format(self):
"""Verify multimodal tuple extra_keys keep the vLLM 0.19 wire shape."""
import msgspec
mm_hash = "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210"
event = BlockStored(
block_hashes=[123],
parent_block_hash=None,
token_ids=[1, 2, 3, 4],
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
extra_keys=[((mm_hash, 7),)],
)
event_kwargs = {
"block_hashes": [123],
"parent_block_hash": None,
"token_ids": [1, 2, 3, 4],
"block_size": 16,
"lora_id": None,
"medium": "GPU",
"lora_name": None,
"extra_keys": [((mm_hash, 7),)],
}
if _has_group_idx(BlockStored):
event_kwargs["group_idx"] = 0
event = BlockStored(**event_kwargs)
decoded = msgspec.msgpack.decode(msgspec.msgpack.encode(event))
......@@ -218,3 +237,31 @@ class TestVllmKvEventsApi:
"vLLM multimodal extra_keys no longer serialize as nested tuple/list "
f"payloads. Decoded: {decoded[8]!r}"
)
if _has_group_idx(BlockStored):
assert decoded[9] == 0, f"group_idx at wrong position: {decoded[9]}"
def test_block_removed_serialization_format(self):
    """Check the msgpack array layout produced for a BlockRemoved event.

    The decoded payload is expected to be ``[tag, block_hashes, medium]``,
    with ``group_idx`` appended as a fourth element when the installed
    BlockRemoved struct declares that field.
    """
    import msgspec

    # Probe once; every expectation below depends on the same answer.
    with_group_idx = _has_group_idx(BlockRemoved)

    event_kwargs = dict(block_hashes=[123, 456], medium="GPU")
    if with_group_idx:
        event_kwargs["group_idx"] = 0

    packed = msgspec.msgpack.encode(BlockRemoved(**event_kwargs))
    decoded = msgspec.msgpack.decode(packed)

    assert decoded[0] == "BlockRemoved"

    expected_len = 4 if with_group_idx else 3
    assert len(decoded) == expected_len, (
        f"Expected {expected_len} elements, got {len(decoded)}.\n"
        f"Decoded: {decoded}\n"
        f"If field count changed, update Rust deserializers."
    )

    # Positional checks: the Rust side reads these fields by index.
    assert decoded[1] == [123, 456], f"block_hashes at wrong position: {decoded[1]}"
    assert decoded[2] == "GPU", f"medium at wrong position: {decoded[2]}"
    if with_group_idx:
        assert decoded[3] == 0, f"group_idx at wrong position: {decoded[3]}"
......@@ -154,13 +154,15 @@ impl ListenerLoop {
.data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let placement_event = convert_event(
let Some(placement_event) = convert_event(
raw_event,
seq,
block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank),
warning_count,
);
) else {
continue;
};
if !placement_event.placement.is_local_gpu() {
continue;
}
......@@ -223,13 +225,15 @@ impl ListenerLoop {
.data_parallel_rank
.map_or(self.dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events {
let placement_event = convert_event(
let Some(placement_event) = convert_event(
raw_event,
seq,
self.block_size,
WorkerWithDpRank::new(self.worker_id, effective_dp_rank),
&self.warning_count,
);
) else {
continue;
};
if !placement_event.placement.is_local_gpu() {
continue;
}
......
......@@ -99,6 +99,7 @@ pub enum RawKvEvent {
medium: Option<String>,
},
AllBlocksCleared,
Ignored,
}
/// Parse MM hash from extra_keys string:
......@@ -126,6 +127,17 @@ pub enum ExtraKeyItem {
Bool(bool),
}
/// One optional trailing element of a sequence-encoded `BlockStored` event.
///
/// The sequence visitor reads up to two of these after `extra_keys` and tells
/// them apart by shape rather than by position: an unsigned integer is stored
/// as the event's `group_idx`, a list is stored as its `block_mm_infos`.
///
/// Because the enum is `#[serde(untagged)]`, serde tries the variants in
/// declaration order and keeps the first one that deserializes, so the order
/// of the variants below is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum BlockStoredTrailingField {
/// Integer element, interpreted as the KV cache group index.
GroupIdx(u32),
/// List element of optional `BlockExtraInfo` entries.
BlockMmInfos(Vec<Option<BlockExtraInfo>>),
}
/// Reports whether `group_idx` names a KV cache group other than the main one.
///
/// Group `0` is the main group. A missing index (`None`) is treated the same
/// as the main group, so events that carry no `group_idx` are not filtered.
fn is_non_main_group(group_idx: Option<u32>) -> bool {
    match group_idx {
        None | Some(0) => false,
        Some(_) => true,
    }
}
/// Convert vLLM BlockStored extra_keys to block-level MM infos.
/// extra_keys is a list aligned with blocks:
/// - None => no MM content in that block
......@@ -224,6 +236,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut lora_name: Option<Option<String>> = None;
let mut extra_keys: Option<Option<Vec<Option<Vec<ExtraKeyItem>>>>> = None;
let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
let mut group_idx: Option<Option<u32>> = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
......@@ -254,6 +267,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"block_mm_infos" => {
block_mm_infos = Some(map.next_value()?);
}
"group_idx" => {
group_idx = Some(map.next_value()?);
}
_ => {
map.next_value::<IgnoredAny>()?;
}
......@@ -278,6 +294,10 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
};
let block_size =
block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
let medium = medium.unwrap_or(None);
if is_non_main_group(group_idx.unwrap_or(None)) {
return Ok(RawKvEvent::Ignored);
}
let block_mm_infos = block_mm_infos
.unwrap_or(None)
.or_else(|| extra_keys_to_block_mm_infos(extra_keys.unwrap_or(None)));
......@@ -286,7 +306,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
parent_block_hash: parent_block_hash.unwrap_or(None),
token_ids: raw_token_ids,
block_size,
medium: medium.unwrap_or(None),
medium,
lora_name: lora_name.unwrap_or(None),
block_mm_infos,
is_eagle: Some(is_eagle),
......@@ -295,15 +315,20 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
Some("BlockRemoved") => {
let block_hashes =
block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?;
let medium = medium.unwrap_or(None);
if is_non_main_group(group_idx.unwrap_or(None)) {
return Ok(RawKvEvent::Ignored);
}
Ok(RawKvEvent::BlockRemoved {
block_hashes,
medium: medium.unwrap_or(None),
medium,
})
}
Some("AllBlocksCleared") => Ok(RawKvEvent::AllBlocksCleared),
Some("Ignored") => Ok(RawKvEvent::Ignored),
Some(other) => Err(de::Error::unknown_variant(
other,
&["BlockStored", "BlockRemoved", "AllBlocksCleared"],
&["BlockStored", "BlockRemoved", "AllBlocksCleared", "Ignored"],
)),
None => Err(de::Error::missing_field("type")),
}
......@@ -339,11 +364,29 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
let extra_keys: Option<Vec<Option<Vec<ExtraKeyItem>>>> =
seq.next_element()?.unwrap_or(None);
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> =
seq.next_element()?.unwrap_or(None);
let mut block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = None;
let mut group_idx: Option<u32> = None;
for _ in 0..2 {
let trailing: Option<BlockStoredTrailingField> =
seq.next_element()?.unwrap_or(None);
match trailing {
Some(BlockStoredTrailingField::GroupIdx(idx)) => {
group_idx = Some(idx);
}
Some(BlockStoredTrailingField::BlockMmInfos(infos)) => {
block_mm_infos = Some(infos);
}
None => {}
}
}
while seq.next_element::<IgnoredAny>()?.is_some() {}
if is_non_main_group(group_idx) {
return Ok(RawKvEvent::Ignored);
}
let block_mm_infos =
block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
......@@ -375,9 +418,14 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?;
let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let group_idx: Option<u32> = seq.next_element()?.unwrap_or(None);
while seq.next_element::<IgnoredAny>()?.is_some() {}
if is_non_main_group(group_idx) {
return Ok(RawKvEvent::Ignored);
}
Ok(RawKvEvent::BlockRemoved {
block_hashes,
medium,
......@@ -387,9 +435,13 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
while seq.next_element::<IgnoredAny>()?.is_some() {}
Ok(RawKvEvent::AllBlocksCleared)
}
"Ignored" => {
while seq.next_element::<IgnoredAny>()?.is_some() {}
Ok(RawKvEvent::Ignored)
}
other => Err(de::Error::unknown_variant(
other,
&["BlockStored", "BlockRemoved", "AllBlocksCleared"],
&["BlockStored", "BlockRemoved", "AllBlocksCleared", "Ignored"],
)),
}
}
......@@ -406,12 +458,13 @@ pub fn convert_event(
kv_block_size: u32,
worker: WorkerWithDpRank,
warning_count: &Arc<AtomicU32>,
) -> PlacementEvent {
) -> Option<PlacementEvent> {
let storage_tier = match &raw {
RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => {
StorageTier::from_kv_medium_or_default(medium.as_deref())
}
RawKvEvent::AllBlocksCleared => StorageTier::Device,
RawKvEvent::Ignored => return None,
};
let dp_rank = worker.dp_rank;
let event = match raw {
......@@ -440,7 +493,7 @@ pub fn convert_event(
// Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())).
return PlacementEvent::new(
return Some(PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
KvCacheEvent {
event_id,
......@@ -449,7 +502,7 @@ pub fn convert_event(
}),
dp_rank,
},
);
));
}
}
......@@ -498,12 +551,13 @@ pub fn convert_event(
data: KvCacheEventData::Cleared,
dp_rank,
},
RawKvEvent::Ignored => unreachable!("ignored events return before conversion"),
};
PlacementEvent::new(
Some(PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
event,
)
))
}
pub fn create_stored_block_from_parts(
......@@ -607,9 +661,16 @@ mod tests {
use std::sync::atomic::AtomicU32;
use rmp_serde::{from_slice, to_vec};
use rstest::rstest;
use super::*;
/// Selects which sequence-encoded event fixture a parameterized test builds
/// and which `RawKvEvent` variant it expects back after deserialization.
#[derive(Clone, Copy, Debug)]
enum TestEventKind {
/// Build, and expect, a `BlockStored` event.
BlockStored,
/// Build, and expect, a `BlockRemoved` event.
BlockRemoved,
}
#[test]
fn test_deserialize_bigram_block_stored_sequence() {
let raw_event = (
......@@ -640,6 +701,97 @@ mod tests {
}
}
/// Encodes a sequence-form `BlockStored` event fixture as MessagePack.
///
/// With `Some(group_idx)` the payload has a `None` placeholder in the
/// `extra_keys` position followed by the group index. With `None` the payload
/// stops after `lora_name`, omitting both `extra_keys` and `group_idx`.
fn block_stored_sequence_with_group_idx(group_idx: Option<u32>) -> Vec<u8> {
    // Fields shared by both payload shapes, in wire order after the tag.
    let block_hashes = vec![BlockHashValue::Unsigned(11)];
    let parent_block_hash: Option<BlockHashValue> = None;
    let token_ids = vec![10u32, 11];
    let block_size = 2usize;
    let lora_id: Option<u64> = None;
    let medium: Option<String> = None;
    let lora_name: Option<String> = None;

    let encoded = if let Some(group_idx) = group_idx {
        let extra_keys: Option<u8> = None;
        to_vec(&(
            "BlockStored",
            block_hashes,
            parent_block_hash,
            token_ids,
            block_size,
            lora_id,
            medium,
            lora_name,
            extra_keys,
            group_idx,
        ))
    } else {
        to_vec(&(
            "BlockStored",
            block_hashes,
            parent_block_hash,
            token_ids,
            block_size,
            lora_id,
            medium,
            lora_name,
        ))
    };
    encoded.unwrap()
}
/// Encodes a sequence-form `BlockRemoved` event fixture as MessagePack.
///
/// The payload is `(tag, block_hashes, medium)`, with the group index appended
/// as a fourth element only when `group_idx` is `Some`.
fn block_removed_sequence_with_group_idx(group_idx: Option<u32>) -> Vec<u8> {
    let block_hashes = vec![BlockHashValue::Unsigned(11)];
    let medium: Option<String> = None;

    let encoded = if let Some(group_idx) = group_idx {
        to_vec(&("BlockRemoved", block_hashes, medium, group_idx))
    } else {
        to_vec(&("BlockRemoved", block_hashes, medium))
    };
    encoded.unwrap()
}
fn sequence_with_group_idx(event_kind: TestEventKind, group_idx: Option<u32>) -> Vec<u8> {
match event_kind {
TestEventKind::BlockStored => block_stored_sequence_with_group_idx(group_idx),
TestEventKind::BlockRemoved => block_removed_sequence_with_group_idx(group_idx),
}
}
/// Panics unless `event` is the `RawKvEvent` variant that corresponds to
/// `expected_kind`.
fn assert_parsed_event_kind(event: RawKvEvent, expected_kind: TestEventKind) {
    let kinds_agree = matches!(
        (&event, expected_kind),
        (RawKvEvent::BlockStored { .. }, TestEventKind::BlockStored)
            | (RawKvEvent::BlockRemoved { .. }, TestEventKind::BlockRemoved)
    );
    if !kinds_agree {
        panic!("expected {expected_kind:?}, got {event:?}");
    }
}
/// An explicit `group_idx` of 0 must still parse as the regular event variant.
#[rstest]
#[case(TestEventKind::BlockStored)]
#[case(TestEventKind::BlockRemoved)]
fn test_deserialize_sequence_accepts_main_group_idx(#[case] event_kind: TestEventKind) {
    let payload = sequence_with_group_idx(event_kind, Some(0));
    let parsed: RawKvEvent = from_slice(&payload).unwrap();
    assert_parsed_event_kind(parsed, event_kind);
}
/// A non-zero `group_idx` must collapse the event to `RawKvEvent::Ignored`.
#[rstest]
#[case(TestEventKind::BlockStored)]
#[case(TestEventKind::BlockRemoved)]
fn test_deserialize_sequence_ignores_non_main_group_idx(#[case] event_kind: TestEventKind) {
    let payload = sequence_with_group_idx(event_kind, Some(1));
    let event: RawKvEvent = from_slice(&payload).unwrap();
    assert!(matches!(event, RawKvEvent::Ignored));
}
/// A payload without any `group_idx` element must parse as the regular event
/// variant.
#[rstest]
#[case(TestEventKind::BlockStored)]
#[case(TestEventKind::BlockRemoved)]
fn test_deserialize_sequence_accepts_missing_group_idx(#[case] event_kind: TestEventKind) {
    let payload = sequence_with_group_idx(event_kind, None);
    let parsed: RawKvEvent = from_slice(&payload).unwrap();
    assert_parsed_event_kind(parsed, event_kind);
}
#[test]
fn test_convert_event_bigram_emits_eagle_windows() {
let raw_event = RawKvEvent::BlockStored {
......@@ -656,7 +808,7 @@ mod tests {
let placement_event =
convert_event(raw_event, 7, 2, WorkerWithDpRank::new(3, 0), &warning_count);
match placement_event.event.data {
match placement_event.unwrap().event.data {
KvCacheEventData::Stored(store_data) => {
assert_eq!(store_data.blocks.len(), 2);
assert_eq!(
......
......@@ -246,5 +246,7 @@ fn process_event(
tracing::debug!("Processing AllBlocksCleared");
tracker.handle_clear_all();
}
RawKvEvent::Ignored => {}
}
}
......@@ -112,7 +112,8 @@ mod test_event_processing {
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
)
.unwrap();
assert!(matches!(out.event.data, KvCacheEventData::Stored(_)));
}
......@@ -149,14 +150,16 @@ mod test_event_processing {
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
)
.unwrap();
let lora_out = convert_event(
lora_evt,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
)
.unwrap();
let base_hash = match &base_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
......@@ -205,14 +208,16 @@ mod test_event_processing {
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
)
.unwrap();
let out2 = convert_event(
evt2,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
)
.unwrap();
let hash1 = match &out1.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
......@@ -297,7 +302,8 @@ mod test_event_processing {
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
)
.unwrap();
assert!(matches!(out.event.data, KvCacheEventData::Removed(_)));
}
......@@ -312,7 +318,8 @@ mod test_event_processing {
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
)
.unwrap();
assert!(matches!(out.event.data, KvCacheEventData::Cleared));
}
......
......@@ -103,10 +103,16 @@ pub(super) async fn start_zmq_listener(
let dp_rank = batch.data_parallel_rank.unwrap_or(0).cast_unsigned();
for raw_event in batch.events {
if matches!(raw_event, RawKvEvent::Ignored) {
continue;
}
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let event =
convert_event(raw_event, event_id, kv_block_size, worker, &warning_count);
let Some(event) =
convert_event(raw_event, event_id, kv_block_size, worker, &warning_count)
else {
continue;
};
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
break 'main String::from("channel receiver dropped");
......
......@@ -72,9 +72,11 @@ enum ZmqRawKvEvent {
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: u32,
group_idx: u32,
},
BlockRemoved {
block_hashes: Vec<u64>,
group_idx: u32,
},
}
......@@ -277,11 +279,15 @@ fn convert_to_zmq_events(
parent_block_hash,
token_ids,
block_size,
group_idx: 0,
}]
}
KvCacheEventData::Removed(remove_data) => {
let block_hashes: Vec<u64> = remove_data.block_hashes.iter().map(|h| h.0).collect();
vec![ZmqRawKvEvent::BlockRemoved { block_hashes }]
vec![ZmqRawKvEvent::BlockRemoved {
block_hashes,
group_idx: 0,
}]
}
KvCacheEventData::Cleared => vec![],
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment