Unverified commit 9d10e050, authored by Kris Hung, committed by GitHub
Browse files

fix: Fix vLLM KVBM integration test (#6768)

parent 93fa4d4d
...@@ -14,6 +14,8 @@ use tokio::task::JoinHandle; ...@@ -14,6 +14,8 @@ use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket}; use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_kv_router::zmq_wire::RawKvEvent;
use super::tracker::{CacheStatusTracker, EventSource, StorageTier}; use super::tracker::{CacheStatusTracker, EventSource, StorageTier};
/// Event batch received from vLLM/TensorRT-LLM (array format) /// Event batch received from vLLM/TensorRT-LLM (array format)
...@@ -24,7 +26,7 @@ use super::tracker::{CacheStatusTracker, EventSource, StorageTier}; ...@@ -24,7 +26,7 @@ use super::tracker::{CacheStatusTracker, EventSource, StorageTier};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct VllmEventBatch( struct VllmEventBatch(
f64, // ts f64, // ts
Vec<VllmRawEvent>, // events Vec<RawKvEvent>, // events — reuses the same custom deserializer as the router publisher
Option<i32>, // data_parallel_rank Option<i32>, // data_parallel_rank
); );
...@@ -33,7 +35,7 @@ impl VllmEventBatch { ...@@ -33,7 +35,7 @@ impl VllmEventBatch {
self.0 self.0
} }
fn events(&self) -> &Vec<VllmRawEvent> { fn events(&self) -> &Vec<RawKvEvent> {
&self.1 &self.1
} }
...@@ -42,124 +44,6 @@ impl VllmEventBatch { ...@@ -42,124 +44,6 @@ impl VllmEventBatch {
} }
} }
/// Block hash can be either an integer or a string (bytes hex-encoded)
///
/// Note: Integers can be u64 or i64 (msgpack compatibility) but we convert to u64 for internal use.
/// - vLLM uses u64 block hashes
/// - TensorRT-LLM uses i64 block hashes (signed integers)
///
/// The variant records which wire representation was seen during
/// deserialization; `to_u64` normalizes all of them to a single `u64`.
#[derive(Debug, Clone)]
enum BlockHash {
/// Hash received as an unsigned integer; used as-is by `to_u64`.
IntU64(u64),
/// Hash received as a signed integer; `to_u64` reinterprets it with `as u64`.
IntI64(i64), // Added for TensorRT-LLM support (uses signed i64 hashes)
/// Hash received as text; `to_u64` parses it as base-16 and yields `None`
/// when the text is not a valid hexadecimal `u64`.
Str(String),
}
impl<'de> serde::Deserialize<'de> for BlockHash {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, Visitor};
use std::fmt;
struct BlockHashVisitor;
impl<'de> Visitor<'de> for BlockHashVisitor {
type Value = BlockHash;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("an integer or a string")
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(BlockHash::IntU64(value))
}
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(BlockHash::IntI64(value))
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(BlockHash::Str(value.to_string()))
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(BlockHash::Str(value))
}
}
deserializer.deserialize_any(BlockHashVisitor)
}
}
impl BlockHash {
/// Convert to u64, handling both signed and unsigned integers
/// Returns None if the hash cannot be converted (e.g., invalid hex string)
/// This avoids silently collapsing invalid hashes to 0, which could cause collisions
fn to_u64(&self) -> Option<u64> {
match self {
BlockHash::IntU64(n) => Some(*n),
BlockHash::IntI64(n) => {
// Convert signed i64 back to unsigned u64 (two's complement)
// Rust's `as u64` automatically handles two's complement conversion
Some(*n as u64)
}
BlockHash::Str(s) => {
// Try to parse as hex string, return None on failure
// This avoids silently mapping invalid hashes to 0, which could cause collisions
u64::from_str_radix(s, 16).ok()
}
}
}
}
impl std::fmt::Display for BlockHash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BlockHash::IntU64(n) => write!(f, "{}", n),
BlockHash::IntI64(n) => write!(f, "{}", n),
BlockHash::Str(s) => write!(f, "{}", s),
}
}
}
/// Raw vLLM event format (preserves all data including token_ids)
///
/// Deserialized as an internally tagged enum: the `type` field of the
/// incoming value selects the variant, and the variant names below are the
/// exact tag strings expected.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
enum VllmRawEvent {
/// One or more blocks were stored.
#[serde(rename = "BlockStored")]
BlockStored {
/// Hashes of the stored blocks, in batch order.
block_hashes: Vec<BlockHash>,
/// Hash of the block preceding the first entry of `block_hashes`, if any.
parent_block_hash: Option<BlockHash>,
/// Token IDs for the whole batch as signed integers.
token_ids: Vec<i32>,
/// Block size as a signed integer.
// NOTE(review): presumably the number of tokens per block — confirm against the producer.
block_size: i32,
/// Storage medium name; `None` when the field is absent.
#[serde(default)]
medium: Option<String>,
/// LoRA adapter name; `None` when the field is absent.
#[serde(default)]
lora_name: Option<String>,
},
/// One or more blocks were removed.
#[serde(rename = "BlockRemoved")]
BlockRemoved {
/// Hashes of the removed blocks.
block_hashes: Vec<BlockHash>,
/// Storage medium name; `None` when the field is absent.
#[serde(default)]
medium: Option<String>,
},
/// All blocks were cleared; carries no fields.
#[serde(rename = "AllBlocksCleared")]
AllBlocksCleared {},
}
/// Start ZMQ listener and process events into tracker /// Start ZMQ listener and process events into tracker
pub async fn start_simple_zmq_listener( pub async fn start_simple_zmq_listener(
endpoint: String, endpoint: String,
...@@ -266,18 +150,19 @@ async fn run_listener_loop( ...@@ -266,18 +150,19 @@ async fn run_listener_loop(
fn process_event( fn process_event(
tracker: &mut CacheStatusTracker, tracker: &mut CacheStatusTracker,
event: VllmRawEvent, event: RawKvEvent,
data_parallel_rank: Option<i32>, data_parallel_rank: Option<i32>,
engine_source: EventSource, engine_source: EventSource,
) { ) {
match event { match event {
VllmRawEvent::BlockStored { RawKvEvent::BlockStored {
block_hashes, block_hashes,
parent_block_hash, parent_block_hash,
token_ids, token_ids,
block_size, block_size,
medium, medium,
lora_name, lora_name,
.. // block_mm_infos not used in consolidator
} => { } => {
let storage_tier = medium let storage_tier = medium
.as_ref() .as_ref()
...@@ -294,32 +179,15 @@ fn process_event( ...@@ -294,32 +179,15 @@ fn process_event(
data_parallel_rank data_parallel_rank
); );
// Convert block_size from i32 to usize for chunking // block_size is already usize; guard against 0 to avoid chunks() panic
// SAFETY: Must validate block_size > 0 to prevent panic in chunks() if block_size == 0 {
let block_size_usize = match usize::try_from(block_size) { tracing::warn!("Invalid block_size 0 (must be positive), skipping event to avoid chunks() panic");
Ok(size) if size > 0 => size,
_ => {
tracing::warn!(
"Invalid block_size {} (must be positive), skipping event to avoid chunks() panic",
block_size
);
return; return;
} }
};
// Convert token_ids from i32 to u32 and split into chunks // token_ids is already Vec<u32>; split directly into per-block chunks
let token_ids_u32: Vec<u32> = token_ids let token_chunks: Vec<Vec<u32>> = token_ids
.into_iter() .chunks(block_size)
.filter_map(|t| {
u32::try_from(t).ok().or_else(|| {
tracing::warn!("Invalid token ID {}, skipping", t);
None
})
})
.collect();
let token_chunks: Vec<Vec<u32>> = token_ids_u32
.chunks(block_size_usize)
.map(|chunk| chunk.to_vec()) .map(|chunk| chunk.to_vec())
.collect(); .collect();
...@@ -332,40 +200,19 @@ fn process_event( ...@@ -332,40 +200,19 @@ fn process_event(
return; return;
} }
// Process each block with its corresponding token chunk // For batches, chain the blocks: each block's parent is the previous block
// For batches, chain the blocks: each block's parent is the previous block in the batch let mut current_parent = parent_block_hash.map(|h| h.into_u64().to_string());
let mut current_parent = parent_block_hash
.as_ref()
.and_then(|h| {
h.to_u64().or_else(|| {
tracing::warn!(
"Skipping parent block hash with unparsable string hash {:?}",
h
);
None
})
})
.map(|h| h.to_string());
for (i, block_hash) in block_hashes.iter().enumerate() { for (i, block_hash) in block_hashes.into_iter().enumerate() {
let block_tokens = token_chunks[i].clone(); let block_tokens = token_chunks[i].clone();
let block_hash_u64 = block_hash.into_u64();
// Skip blocks with invalid/unparsable hashes to avoid collisions
let Some(block_hash_u64) = block_hash.to_u64() else {
tracing::warn!(
"Skipping block with unparsable string hash {:?} (index {})",
block_hash,
i
);
continue;
};
tracker.handle_store( tracker.handle_store(
block_hash_u64.to_string(), block_hash_u64.to_string(),
engine_source, engine_source,
block_tokens, block_tokens,
current_parent.clone(), current_parent.clone(),
block_size_usize, block_size,
lora_name.clone(), lora_name.clone(),
Some(storage_tier), Some(storage_tier),
data_parallel_rank, data_parallel_rank,
...@@ -376,10 +223,7 @@ fn process_event( ...@@ -376,10 +223,7 @@ fn process_event(
} }
} }
VllmRawEvent::BlockRemoved { RawKvEvent::BlockRemoved { block_hashes, medium } => {
block_hashes,
medium,
} => {
let storage_tier = medium let storage_tier = medium
.as_ref() .as_ref()
.and_then(|m| StorageTier::from_vllm_medium(m)) .and_then(|m| StorageTier::from_vllm_medium(m))
...@@ -392,19 +236,11 @@ fn process_event( ...@@ -392,19 +236,11 @@ fn process_event(
); );
for block_hash in block_hashes { for block_hash in block_hashes {
// Skip blocks with invalid/unparsable hashes to avoid collisions tracker.handle_remove(&block_hash.into_u64().to_string(), engine_source);
let Some(block_hash_u64) = block_hash.to_u64() else {
tracing::warn!(
"Skipping removal of block with unparsable string hash {:?}",
block_hash
);
continue;
};
tracker.handle_remove(&block_hash_u64.to_string(), engine_source);
} }
} }
VllmRawEvent::AllBlocksCleared {} => { RawKvEvent::AllBlocksCleared => {
tracing::debug!("Processing AllBlocksCleared"); tracing::debug!("Processing AllBlocksCleared");
tracker.handle_clear_all(); tracker.handle_clear_all();
} }
......
...@@ -123,7 +123,6 @@ class LLMServerManager: ...@@ -123,7 +123,6 @@ class LLMServerManager:
"DYN_KVBM_METRICS_PORT": str(self.metrics_port), "DYN_KVBM_METRICS_PORT": str(self.metrics_port),
# Enable vLLM batch invariant for deterministic batching # Enable vLLM batch invariant for deterministic batching
"VLLM_BATCH_INVARIANT": "1", "VLLM_BATCH_INVARIANT": "1",
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
} }
) )
...@@ -154,6 +153,8 @@ class LLMServerManager: ...@@ -154,6 +153,8 @@ class LLMServerManager:
"--kv-transfer-config", "--kv-transfer-config",
'{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "kvbm.vllm_integration.connector"}', '{"kv_connector":"DynamoConnector","kv_role":"kv_both", "kv_connector_module_path": "kvbm.vllm_integration.connector"}',
os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"), os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
"--attention-config.backend",
"FLASH_ATTN",
"--max-model-len", "--max-model-len",
"8000", # required to fit on L4 GPU when using 8b model "8000", # required to fit on L4 GPU when using 8b model
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment