Unverified commit 01002df7, authored by Yan Ru Pei and committed by GitHub
Browse files

fix(kv-router): filter non-main ZMQ KV event groups (cherry-pick) (#8705)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 7d4572f9
...@@ -46,6 +46,10 @@ pytestmark = [ ...@@ -46,6 +46,10 @@ pytestmark = [
] ]
def _has_group_idx(event_cls):
    """Return True when ``event_cls`` lists ``group_idx`` among its struct fields.

    The check reads ``event_cls.__struct_fields__``, so the class must expose
    that attribute. The tests use the result to decide whether the optional
    trailing ``group_idx`` field should be populated and asserted on.
    """
    return "group_idx" in event_cls.__struct_fields__
class TestVllmKvEventsApi: class TestVllmKvEventsApi:
"""Test vLLM KV events API compatibility.""" """Test vLLM KV events API compatibility."""
...@@ -61,10 +65,11 @@ class TestVllmKvEventsApi: ...@@ -61,10 +65,11 @@ class TestVllmKvEventsApi:
6. medium 6. medium
7. lora_name (added in vLLM 0.14.0) 7. lora_name (added in vLLM 0.14.0)
8. extra_keys (added in vLLM 0.17.0) 8. extra_keys (added in vLLM 0.17.0)
9. group_idx (added for hybrid KV cache groups; optional for older vLLM)
If vLLM adds/removes/reorders fields, this test will fail. If vLLM adds/removes/reorders fields, this test will fail.
""" """
expected_fields = ( expected_fields = [
"block_hashes", "block_hashes",
"parent_block_hash", "parent_block_hash",
"token_ids", "token_ids",
...@@ -73,7 +78,10 @@ class TestVllmKvEventsApi: ...@@ -73,7 +78,10 @@ class TestVllmKvEventsApi:
"medium", "medium",
"lora_name", "lora_name",
"extra_keys", "extra_keys",
) ]
if _has_group_idx(BlockStored):
expected_fields.append("group_idx")
expected_fields = tuple(expected_fields)
actual_fields = BlockStored.__struct_fields__ actual_fields = BlockStored.__struct_fields__
assert actual_fields == expected_fields, ( assert actual_fields == expected_fields, (
...@@ -88,10 +96,13 @@ class TestVllmKvEventsApi: ...@@ -88,10 +96,13 @@ class TestVllmKvEventsApi:
def test_block_removed_fields(self): def test_block_removed_fields(self):
"""Verify BlockRemoved has expected fields in expected order.""" """Verify BlockRemoved has expected fields in expected order."""
expected_fields = ( expected_fields = [
"block_hashes", "block_hashes",
"medium", "medium",
) ]
if _has_group_idx(BlockRemoved):
expected_fields.append("group_idx")
expected_fields = tuple(expected_fields)
actual_fields = BlockRemoved.__struct_fields__ actual_fields = BlockRemoved.__struct_fields__
assert actual_fields == expected_fields, ( assert actual_fields == expected_fields, (
...@@ -158,16 +169,19 @@ class TestVllmKvEventsApi: ...@@ -158,16 +169,19 @@ class TestVllmKvEventsApi:
""" """
import msgspec import msgspec
event = BlockStored( event_kwargs = {
block_hashes=[123, 456], "block_hashes": [123, 456],
parent_block_hash=789, "parent_block_hash": 789,
token_ids=[1, 2, 3, 4], "token_ids": [1, 2, 3, 4],
block_size=16, "block_size": 16,
lora_id=None, "lora_id": None,
medium="GPU", "medium": "GPU",
lora_name=None, "lora_name": None,
extra_keys=None, "extra_keys": None,
) }
if _has_group_idx(BlockStored):
event_kwargs["group_idx"] = 0
event = BlockStored(**event_kwargs)
encoded = msgspec.msgpack.encode(event) encoded = msgspec.msgpack.encode(event)
decoded = msgspec.msgpack.decode(encoded) decoded = msgspec.msgpack.decode(encoded)
...@@ -178,9 +192,9 @@ class TestVllmKvEventsApi: ...@@ -178,9 +192,9 @@ class TestVllmKvEventsApi:
decoded[0] == "BlockStored" decoded[0] == "BlockStored"
), f"Expected tag 'BlockStored', got {decoded[0]}" ), f"Expected tag 'BlockStored', got {decoded[0]}"
# Verify field count (tag + 8 fields = 9 elements) expected_len = 10 if _has_group_idx(BlockStored) else 9
assert len(decoded) == 9, ( assert len(decoded) == expected_len, (
f"Expected 9 elements (tag + 8 fields), got {len(decoded)}.\n" f"Expected {expected_len} elements, got {len(decoded)}.\n"
f"Decoded: {decoded}\n" f"Decoded: {decoded}\n"
f"If field count changed, update Rust deserializers." f"If field count changed, update Rust deserializers."
) )
...@@ -194,22 +208,27 @@ class TestVllmKvEventsApi: ...@@ -194,22 +208,27 @@ class TestVllmKvEventsApi:
assert decoded[6] == "GPU", f"medium at wrong position: {decoded[6]}" assert decoded[6] == "GPU", f"medium at wrong position: {decoded[6]}"
assert decoded[7] is None, f"lora_name at wrong position: {decoded[7]}" assert decoded[7] is None, f"lora_name at wrong position: {decoded[7]}"
assert decoded[8] is None, f"extra_keys at wrong position: {decoded[8]}" assert decoded[8] is None, f"extra_keys at wrong position: {decoded[8]}"
if _has_group_idx(BlockStored):
assert decoded[9] == 0, f"group_idx at wrong position: {decoded[9]}"
def test_block_stored_tuple_extra_keys_serialization_format(self): def test_block_stored_tuple_extra_keys_serialization_format(self):
"""Verify multimodal tuple extra_keys keep the vLLM 0.19 wire shape.""" """Verify multimodal tuple extra_keys keep the vLLM 0.19 wire shape."""
import msgspec import msgspec
mm_hash = "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210" mm_hash = "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210"
event = BlockStored( event_kwargs = {
block_hashes=[123], "block_hashes": [123],
parent_block_hash=None, "parent_block_hash": None,
token_ids=[1, 2, 3, 4], "token_ids": [1, 2, 3, 4],
block_size=16, "block_size": 16,
lora_id=None, "lora_id": None,
medium="GPU", "medium": "GPU",
lora_name=None, "lora_name": None,
extra_keys=[((mm_hash, 7),)], "extra_keys": [((mm_hash, 7),)],
) }
if _has_group_idx(BlockStored):
event_kwargs["group_idx"] = 0
event = BlockStored(**event_kwargs)
decoded = msgspec.msgpack.decode(msgspec.msgpack.encode(event)) decoded = msgspec.msgpack.decode(msgspec.msgpack.encode(event))
...@@ -218,3 +237,31 @@ class TestVllmKvEventsApi: ...@@ -218,3 +237,31 @@ class TestVllmKvEventsApi:
"vLLM multimodal extra_keys no longer serialize as nested tuple/list " "vLLM multimodal extra_keys no longer serialize as nested tuple/list "
f"payloads. Decoded: {decoded[8]!r}" f"payloads. Decoded: {decoded[8]!r}"
) )
if _has_group_idx(BlockStored):
assert decoded[9] == 0, f"group_idx at wrong position: {decoded[9]}"
def test_block_removed_serialization_format(self):
    """Verify BlockRemoved serializes to expected msgpack array format."""
    import msgspec

    event_kwargs = {
        "block_hashes": [123, 456],
        "medium": "GPU",
    }
    # Only pass group_idx when this BlockRemoved class declares the field.
    if _has_group_idx(BlockRemoved):
        event_kwargs["group_idx"] = 0
    event = BlockRemoved(**event_kwargs)

    # Round-trip through msgpack and inspect the decoded array.
    decoded = msgspec.msgpack.decode(msgspec.msgpack.encode(event))

    assert decoded[0] == "BlockRemoved"

    # Tag + block_hashes + medium, plus group_idx when the field exists.
    expected_len = 4 if _has_group_idx(BlockRemoved) else 3
    assert len(decoded) == expected_len, (
        f"Expected {expected_len} elements, got {len(decoded)}.\n"
        f"Decoded: {decoded}\n"
        f"If field count changed, update Rust deserializers."
    )

    # Positional checks: the field order is what these assertions pin down.
    assert decoded[1] == [123, 456], f"block_hashes at wrong position: {decoded[1]}"
    assert decoded[2] == "GPU", f"medium at wrong position: {decoded[2]}"
    if _has_group_idx(BlockRemoved):
        assert decoded[3] == 0, f"group_idx at wrong position: {decoded[3]}"
...@@ -154,13 +154,15 @@ impl ListenerLoop { ...@@ -154,13 +154,15 @@ impl ListenerLoop {
.data_parallel_rank .data_parallel_rank
.map_or(dp_rank, |rank| rank.cast_unsigned()); .map_or(dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events { for raw_event in batch.events {
let placement_event = convert_event( let Some(placement_event) = convert_event(
raw_event, raw_event,
seq, seq,
block_size, block_size,
WorkerWithDpRank::new(worker_id, effective_dp_rank), WorkerWithDpRank::new(worker_id, effective_dp_rank),
warning_count, warning_count,
); ) else {
continue;
};
if !placement_event.placement.is_local_gpu() { if !placement_event.placement.is_local_gpu() {
continue; continue;
} }
...@@ -223,13 +225,15 @@ impl ListenerLoop { ...@@ -223,13 +225,15 @@ impl ListenerLoop {
.data_parallel_rank .data_parallel_rank
.map_or(self.dp_rank, |rank| rank.cast_unsigned()); .map_or(self.dp_rank, |rank| rank.cast_unsigned());
for raw_event in batch.events { for raw_event in batch.events {
let placement_event = convert_event( let Some(placement_event) = convert_event(
raw_event, raw_event,
seq, seq,
self.block_size, self.block_size,
WorkerWithDpRank::new(self.worker_id, effective_dp_rank), WorkerWithDpRank::new(self.worker_id, effective_dp_rank),
&self.warning_count, &self.warning_count,
); ) else {
continue;
};
if !placement_event.placement.is_local_gpu() { if !placement_event.placement.is_local_gpu() {
continue; continue;
} }
......
...@@ -99,6 +99,7 @@ pub enum RawKvEvent { ...@@ -99,6 +99,7 @@ pub enum RawKvEvent {
medium: Option<String>, medium: Option<String>,
}, },
AllBlocksCleared, AllBlocksCleared,
Ignored,
} }
/// Parse MM hash from extra_keys string: /// Parse MM hash from extra_keys string:
...@@ -126,6 +127,17 @@ pub enum ExtraKeyItem { ...@@ -126,6 +127,17 @@ pub enum ExtraKeyItem {
Bool(bool), Bool(bool),
} }
/// One optional element that may follow `extra_keys` in a sequence-encoded
/// `BlockStored` event.
///
/// The enum is untagged, so the shape of the encoded value selects the
/// variant. Variant order is significant for untagged enums (serde tries
/// them in declaration order), so do not reorder the variants.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum BlockStoredTrailingField {
    /// An unsigned integer, read as the event's `group_idx`.
    GroupIdx(u32),
    /// A list of optional per-block entries, read as `block_mm_infos`.
    BlockMmInfos(Vec<Option<BlockExtraInfo>>),
}
/// Returns `true` only when a group index is present and is not `0`.
///
/// A missing index (`None`) and index `0` both return `false`.
fn is_non_main_group(group_idx: Option<u32>) -> bool {
    match group_idx {
        None | Some(0) => false,
        Some(_) => true,
    }
}
/// Convert vLLM BlockStored extra_keys to block-level MM infos. /// Convert vLLM BlockStored extra_keys to block-level MM infos.
/// extra_keys is a list aligned with blocks: /// extra_keys is a list aligned with blocks:
/// - None => no MM content in that block /// - None => no MM content in that block
...@@ -224,6 +236,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -224,6 +236,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut lora_name: Option<Option<String>> = None; let mut lora_name: Option<Option<String>> = None;
let mut extra_keys: Option<Option<Vec<Option<Vec<ExtraKeyItem>>>>> = None; let mut extra_keys: Option<Option<Vec<Option<Vec<ExtraKeyItem>>>>> = None;
let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None; let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
let mut group_idx: Option<Option<u32>> = None;
while let Some(key) = map.next_key::<String>()? { while let Some(key) = map.next_key::<String>()? {
match key.as_str() { match key.as_str() {
...@@ -254,6 +267,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -254,6 +267,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"block_mm_infos" => { "block_mm_infos" => {
block_mm_infos = Some(map.next_value()?); block_mm_infos = Some(map.next_value()?);
} }
"group_idx" => {
group_idx = Some(map.next_value()?);
}
_ => { _ => {
map.next_value::<IgnoredAny>()?; map.next_value::<IgnoredAny>()?;
} }
...@@ -278,6 +294,10 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -278,6 +294,10 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
}; };
let block_size = let block_size =
block_size.ok_or_else(|| de::Error::missing_field("block_size"))?; block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
let medium = medium.unwrap_or(None);
if is_non_main_group(group_idx.unwrap_or(None)) {
return Ok(RawKvEvent::Ignored);
}
let block_mm_infos = block_mm_infos let block_mm_infos = block_mm_infos
.unwrap_or(None) .unwrap_or(None)
.or_else(|| extra_keys_to_block_mm_infos(extra_keys.unwrap_or(None))); .or_else(|| extra_keys_to_block_mm_infos(extra_keys.unwrap_or(None)));
...@@ -286,7 +306,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -286,7 +306,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
parent_block_hash: parent_block_hash.unwrap_or(None), parent_block_hash: parent_block_hash.unwrap_or(None),
token_ids: raw_token_ids, token_ids: raw_token_ids,
block_size, block_size,
medium: medium.unwrap_or(None), medium,
lora_name: lora_name.unwrap_or(None), lora_name: lora_name.unwrap_or(None),
block_mm_infos, block_mm_infos,
is_eagle: Some(is_eagle), is_eagle: Some(is_eagle),
...@@ -295,15 +315,20 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -295,15 +315,20 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
Some("BlockRemoved") => { Some("BlockRemoved") => {
let block_hashes = let block_hashes =
block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?; block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?;
let medium = medium.unwrap_or(None);
if is_non_main_group(group_idx.unwrap_or(None)) {
return Ok(RawKvEvent::Ignored);
}
Ok(RawKvEvent::BlockRemoved { Ok(RawKvEvent::BlockRemoved {
block_hashes, block_hashes,
medium: medium.unwrap_or(None), medium,
}) })
} }
Some("AllBlocksCleared") => Ok(RawKvEvent::AllBlocksCleared), Some("AllBlocksCleared") => Ok(RawKvEvent::AllBlocksCleared),
Some("Ignored") => Ok(RawKvEvent::Ignored),
Some(other) => Err(de::Error::unknown_variant( Some(other) => Err(de::Error::unknown_variant(
other, other,
&["BlockStored", "BlockRemoved", "AllBlocksCleared"], &["BlockStored", "BlockRemoved", "AllBlocksCleared", "Ignored"],
)), )),
None => Err(de::Error::missing_field("type")), None => Err(de::Error::missing_field("type")),
} }
...@@ -339,11 +364,29 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -339,11 +364,29 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let lora_name: Option<String> = seq.next_element()?.unwrap_or(None); let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
let extra_keys: Option<Vec<Option<Vec<ExtraKeyItem>>>> = let extra_keys: Option<Vec<Option<Vec<ExtraKeyItem>>>> =
seq.next_element()?.unwrap_or(None); seq.next_element()?.unwrap_or(None);
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = let mut block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> = None;
seq.next_element()?.unwrap_or(None); let mut group_idx: Option<u32> = None;
for _ in 0..2 {
let trailing: Option<BlockStoredTrailingField> =
seq.next_element()?.unwrap_or(None);
match trailing {
Some(BlockStoredTrailingField::GroupIdx(idx)) => {
group_idx = Some(idx);
}
Some(BlockStoredTrailingField::BlockMmInfos(infos)) => {
block_mm_infos = Some(infos);
}
None => {}
}
}
while seq.next_element::<IgnoredAny>()?.is_some() {} while seq.next_element::<IgnoredAny>()?.is_some() {}
if is_non_main_group(group_idx) {
return Ok(RawKvEvent::Ignored);
}
let block_mm_infos = let block_mm_infos =
block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys)); block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
...@@ -375,9 +418,14 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -375,9 +418,14 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
.next_element()? .next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?; .ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?;
let medium: Option<String> = seq.next_element()?.unwrap_or(None); let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let group_idx: Option<u32> = seq.next_element()?.unwrap_or(None);
while seq.next_element::<IgnoredAny>()?.is_some() {} while seq.next_element::<IgnoredAny>()?.is_some() {}
if is_non_main_group(group_idx) {
return Ok(RawKvEvent::Ignored);
}
Ok(RawKvEvent::BlockRemoved { Ok(RawKvEvent::BlockRemoved {
block_hashes, block_hashes,
medium, medium,
...@@ -387,9 +435,13 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -387,9 +435,13 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
while seq.next_element::<IgnoredAny>()?.is_some() {} while seq.next_element::<IgnoredAny>()?.is_some() {}
Ok(RawKvEvent::AllBlocksCleared) Ok(RawKvEvent::AllBlocksCleared)
} }
"Ignored" => {
while seq.next_element::<IgnoredAny>()?.is_some() {}
Ok(RawKvEvent::Ignored)
}
other => Err(de::Error::unknown_variant( other => Err(de::Error::unknown_variant(
other, other,
&["BlockStored", "BlockRemoved", "AllBlocksCleared"], &["BlockStored", "BlockRemoved", "AllBlocksCleared", "Ignored"],
)), )),
} }
} }
...@@ -406,12 +458,13 @@ pub fn convert_event( ...@@ -406,12 +458,13 @@ pub fn convert_event(
kv_block_size: u32, kv_block_size: u32,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> PlacementEvent { ) -> Option<PlacementEvent> {
let storage_tier = match &raw { let storage_tier = match &raw {
RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => { RawKvEvent::BlockStored { medium, .. } | RawKvEvent::BlockRemoved { medium, .. } => {
StorageTier::from_kv_medium_or_default(medium.as_deref()) StorageTier::from_kv_medium_or_default(medium.as_deref())
} }
RawKvEvent::AllBlocksCleared => StorageTier::Device, RawKvEvent::AllBlocksCleared => StorageTier::Device,
RawKvEvent::Ignored => return None,
}; };
let dp_rank = worker.dp_rank; let dp_rank = worker.dp_rank;
let event = match raw { let event = match raw {
...@@ -440,7 +493,7 @@ pub fn convert_event( ...@@ -440,7 +493,7 @@ pub fn convert_event(
// Return an empty Removed instead of Cleared to avoid nuking // Return an empty Removed instead of Cleared to avoid nuking
// the worker's entire index state. An empty Removed is a no-op // the worker's entire index state. An empty Removed is a no-op
// in the radix tree (zero iterations, returns Ok(())). // in the radix tree (zero iterations, returns Ok(())).
return PlacementEvent::new( return Some(PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier), Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
KvCacheEvent { KvCacheEvent {
event_id, event_id,
...@@ -449,7 +502,7 @@ pub fn convert_event( ...@@ -449,7 +502,7 @@ pub fn convert_event(
}), }),
dp_rank, dp_rank,
}, },
); ));
} }
} }
...@@ -498,12 +551,13 @@ pub fn convert_event( ...@@ -498,12 +551,13 @@ pub fn convert_event(
data: KvCacheEventData::Cleared, data: KvCacheEventData::Cleared,
dp_rank, dp_rank,
}, },
RawKvEvent::Ignored => unreachable!("ignored events return before conversion"),
}; };
PlacementEvent::new( Some(PlacementEvent::new(
Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier), Placement::local_worker(worker.worker_id, worker.dp_rank, storage_tier),
event, event,
) ))
} }
pub fn create_stored_block_from_parts( pub fn create_stored_block_from_parts(
...@@ -607,9 +661,16 @@ mod tests { ...@@ -607,9 +661,16 @@ mod tests {
use std::sync::atomic::AtomicU32; use std::sync::atomic::AtomicU32;
use rmp_serde::{from_slice, to_vec}; use rmp_serde::{from_slice, to_vec};
use rstest::rstest;
use super::*; use super::*;
/// Selects which event kind a parameterized deserialization test builds
/// and expects back.
#[derive(Debug, Clone, Copy)]
enum TestEventKind {
    /// Exercise a `BlockStored` payload.
    BlockStored,
    /// Exercise a `BlockRemoved` payload.
    BlockRemoved,
}
#[test] #[test]
fn test_deserialize_bigram_block_stored_sequence() { fn test_deserialize_bigram_block_stored_sequence() {
let raw_event = ( let raw_event = (
...@@ -640,6 +701,97 @@ mod tests { ...@@ -640,6 +701,97 @@ mod tests {
} }
} }
fn block_stored_sequence_with_group_idx(group_idx: Option<u32>) -> Vec<u8> {
match group_idx {
Some(group_idx) => to_vec(&(
"BlockStored",
vec![BlockHashValue::Unsigned(11)],
Option::<BlockHashValue>::None,
vec![10u32, 11],
2usize,
Option::<u64>::None,
Option::<String>::None,
Option::<String>::None,
Option::<u8>::None,
group_idx,
))
.unwrap(),
None => to_vec(&(
"BlockStored",
vec![BlockHashValue::Unsigned(11)],
Option::<BlockHashValue>::None,
vec![10u32, 11],
2usize,
Option::<u64>::None,
Option::<String>::None,
Option::<String>::None,
))
.unwrap(),
}
}
fn block_removed_sequence_with_group_idx(group_idx: Option<u32>) -> Vec<u8> {
match group_idx {
Some(group_idx) => to_vec(&(
"BlockRemoved",
vec![BlockHashValue::Unsigned(11)],
Option::<String>::None,
group_idx,
))
.unwrap(),
None => to_vec(&(
"BlockRemoved",
vec![BlockHashValue::Unsigned(11)],
Option::<String>::None,
))
.unwrap(),
}
}
fn sequence_with_group_idx(event_kind: TestEventKind, group_idx: Option<u32>) -> Vec<u8> {
match event_kind {
TestEventKind::BlockStored => block_stored_sequence_with_group_idx(group_idx),
TestEventKind::BlockRemoved => block_removed_sequence_with_group_idx(group_idx),
}
}
fn assert_parsed_event_kind(event: RawKvEvent, expected_kind: TestEventKind) {
match (event, expected_kind) {
(RawKvEvent::BlockStored { .. }, TestEventKind::BlockStored)
| (RawKvEvent::BlockRemoved { .. }, TestEventKind::BlockRemoved) => {}
(event, expected_kind) => {
panic!("expected {expected_kind:?}, got {event:?}");
}
}
}
#[rstest]
#[case(TestEventKind::BlockStored)]
#[case(TestEventKind::BlockRemoved)]
fn test_deserialize_sequence_accepts_main_group_idx(#[case] event_kind: TestEventKind) {
let event: RawKvEvent = from_slice(&sequence_with_group_idx(event_kind, Some(0))).unwrap();
assert_parsed_event_kind(event, event_kind);
}
#[rstest]
#[case(TestEventKind::BlockStored)]
#[case(TestEventKind::BlockRemoved)]
fn test_deserialize_sequence_ignores_non_main_group_idx(#[case] event_kind: TestEventKind) {
let event: RawKvEvent = from_slice(&sequence_with_group_idx(event_kind, Some(1))).unwrap();
assert!(matches!(event, RawKvEvent::Ignored));
}
#[rstest]
#[case(TestEventKind::BlockStored)]
#[case(TestEventKind::BlockRemoved)]
fn test_deserialize_sequence_accepts_missing_group_idx(#[case] event_kind: TestEventKind) {
let event: RawKvEvent = from_slice(&sequence_with_group_idx(event_kind, None)).unwrap();
assert_parsed_event_kind(event, event_kind);
}
#[test] #[test]
fn test_convert_event_bigram_emits_eagle_windows() { fn test_convert_event_bigram_emits_eagle_windows() {
let raw_event = RawKvEvent::BlockStored { let raw_event = RawKvEvent::BlockStored {
...@@ -656,7 +808,7 @@ mod tests { ...@@ -656,7 +808,7 @@ mod tests {
let placement_event = let placement_event =
convert_event(raw_event, 7, 2, WorkerWithDpRank::new(3, 0), &warning_count); convert_event(raw_event, 7, 2, WorkerWithDpRank::new(3, 0), &warning_count);
match placement_event.event.data { match placement_event.unwrap().event.data {
KvCacheEventData::Stored(store_data) => { KvCacheEventData::Stored(store_data) => {
assert_eq!(store_data.blocks.len(), 2); assert_eq!(store_data.blocks.len(), 2);
assert_eq!( assert_eq!(
......
...@@ -246,5 +246,7 @@ fn process_event( ...@@ -246,5 +246,7 @@ fn process_event(
tracing::debug!("Processing AllBlocksCleared"); tracing::debug!("Processing AllBlocksCleared");
tracker.handle_clear_all(); tracker.handle_clear_all();
} }
RawKvEvent::Ignored => {}
} }
} }
...@@ -112,7 +112,8 @@ mod test_event_processing { ...@@ -112,7 +112,8 @@ mod test_event_processing {
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)), &Arc::new(AtomicU32::new(0)),
); )
.unwrap();
assert!(matches!(out.event.data, KvCacheEventData::Stored(_))); assert!(matches!(out.event.data, KvCacheEventData::Stored(_)));
} }
...@@ -149,14 +150,16 @@ mod test_event_processing { ...@@ -149,14 +150,16 @@ mod test_event_processing {
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&wc, &wc,
); )
.unwrap();
let lora_out = convert_event( let lora_out = convert_event(
lora_evt, lora_evt,
2, 2,
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&wc, &wc,
); )
.unwrap();
let base_hash = match &base_out.event.data { let base_hash = match &base_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash, KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
...@@ -205,14 +208,16 @@ mod test_event_processing { ...@@ -205,14 +208,16 @@ mod test_event_processing {
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&wc, &wc,
); )
.unwrap();
let out2 = convert_event( let out2 = convert_event(
evt2, evt2,
2, 2,
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&wc, &wc,
); )
.unwrap();
let hash1 = match &out1.event.data { let hash1 = match &out1.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash, KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
...@@ -297,7 +302,8 @@ mod test_event_processing { ...@@ -297,7 +302,8 @@ mod test_event_processing {
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)), &Arc::new(AtomicU32::new(0)),
); )
.unwrap();
assert!(matches!(out.event.data, KvCacheEventData::Removed(_))); assert!(matches!(out.event.data, KvCacheEventData::Removed(_)));
} }
...@@ -312,7 +318,8 @@ mod test_event_processing { ...@@ -312,7 +318,8 @@ mod test_event_processing {
kv_block_size, kv_block_size,
WorkerWithDpRank::from_worker_id(1), WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)), &Arc::new(AtomicU32::new(0)),
); )
.unwrap();
assert!(matches!(out.event.data, KvCacheEventData::Cleared)); assert!(matches!(out.event.data, KvCacheEventData::Cleared));
} }
......
...@@ -103,10 +103,16 @@ pub(super) async fn start_zmq_listener( ...@@ -103,10 +103,16 @@ pub(super) async fn start_zmq_listener(
let dp_rank = batch.data_parallel_rank.unwrap_or(0).cast_unsigned(); let dp_rank = batch.data_parallel_rank.unwrap_or(0).cast_unsigned();
for raw_event in batch.events { for raw_event in batch.events {
if matches!(raw_event, RawKvEvent::Ignored) {
continue;
}
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst); let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let event = let Some(event) =
convert_event(raw_event, event_id, kv_block_size, worker, &warning_count); convert_event(raw_event, event_id, kv_block_size, worker, &warning_count)
else {
continue;
};
if tx.send(event).is_err() { if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped"); tracing::warn!("Failed to send message to channel - receiver dropped");
break 'main String::from("channel receiver dropped"); break 'main String::from("channel receiver dropped");
......
...@@ -72,9 +72,11 @@ enum ZmqRawKvEvent { ...@@ -72,9 +72,11 @@ enum ZmqRawKvEvent {
parent_block_hash: Option<u64>, parent_block_hash: Option<u64>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
block_size: u32, block_size: u32,
group_idx: u32,
}, },
BlockRemoved { BlockRemoved {
block_hashes: Vec<u64>, block_hashes: Vec<u64>,
group_idx: u32,
}, },
} }
...@@ -277,11 +279,15 @@ fn convert_to_zmq_events( ...@@ -277,11 +279,15 @@ fn convert_to_zmq_events(
parent_block_hash, parent_block_hash,
token_ids, token_ids,
block_size, block_size,
group_idx: 0,
}] }]
} }
KvCacheEventData::Removed(remove_data) => { KvCacheEventData::Removed(remove_data) => {
let block_hashes: Vec<u64> = remove_data.block_hashes.iter().map(|h| h.0).collect(); let block_hashes: Vec<u64> = remove_data.block_hashes.iter().map(|h| h.0).collect();
vec![ZmqRawKvEvent::BlockRemoved { block_hashes }] vec![ZmqRawKvEvent::BlockRemoved {
block_hashes,
group_idx: 0,
}]
} }
KvCacheEventData::Cleared => vec![], KvCacheEventData::Cleared => vec![],
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment