Unverified Commit 75fea787 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: make KV cache events and routing LoRA-aware (#6517)

parent 432fae67
......@@ -116,10 +116,10 @@ class ZmqKvEventPublisher:
token_ids: list[int],
num_block_tokens: list[int],
block_hashes: list[int],
lora_id: int = 0,
parent_hash: Optional[int] = None,
block_mm_infos: Optional[list[dict | None]] = None,
attention_dp_rank: int = 0,
lora_name: Optional[str] = None,
):
"""Publish a BlockStored event.
......@@ -139,8 +139,9 @@ class ZmqKvEventPublisher:
"parent_block_hash": parent_hash_signed,
"token_ids": token_ids,
"block_size": self.kv_block_size,
"lora_id": lora_id if lora_id != 0 else None,
}
if lora_name is not None:
event["lora_name"] = lora_name
# Add multimodal info if present
if block_mm_infos is not None:
......@@ -597,17 +598,14 @@ class Publisher:
else:
block_mm_infos.append(None)
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
lora_name = data.get("lora_name")
# Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent)
# Default to 0 for backwards compatibility with older TRT-LLM versions
attention_dp_rank = event.get("attention_dp_rank", 0)
logging.debug(
f"publish stored event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
f"publish stored event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_name: {lora_name}, parent_hash: {parent_hash}"
)
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
......@@ -617,10 +615,10 @@ class Publisher:
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
block_mm_infos,
attention_dp_rank,
lora_name,
)
elif self.kv_event_publishers:
# No consolidator: publish to NATS (router subscribes directly)
......@@ -631,9 +629,9 @@ class Publisher:
token_ids,
num_block_tokens,
block_hashes,
lora_id,
parent_hash,
block_mm_infos,
lora_name=lora_name,
)
else:
logging.warning(
......
......@@ -88,7 +88,7 @@ To get a feel for how KV Cache management works on a single worker with KV Cache
1. **Request tokenization**: The incoming prompt is converted into tokens
2. **Block partitioning**: The token sequence is divided into fixed-size blocks (e.g., 16 or 64 tokens per block)
3. **Block hashing**: Each block of tokens is hashed to create a unique identifier
3. **Block hashing**: Each block of tokens is hashed to create a unique identifier. When a LoRA adapter is active, the adapter name is incorporated into the hash so that blocks cached under different adapters produce distinct identifiers.
4. **Cache lookup**:
- For each block, the system checks if a matching block already exists in the KV cache
- If a match is found, the existing KV cache block is reused
......
......@@ -306,9 +306,20 @@ kubectl logs deployment/my-worker | grep -i lora
- Check that the LoRA is loaded on the worker handling your request
- For disaggregated serving, ensure both prefill and decode workers have the LoRA
## KV Cache-Aware LoRA Routing
When KV-aware routing is enabled, the router automatically accounts for LoRA adapter identity when computing block hashes. This means:
- **Distinct hash spaces per adapter**: Blocks cached under adapter `A` will never be confused with blocks cached under adapter `B` or the base model, even if the token sequences are identical. The adapter name is mixed into the `LocalBlockHash` computation.
- **Automatic prefix sharing within the same adapter**: Requests targeting the same LoRA adapter benefit from KV cache prefix matching just like base model requests do.
- **No configuration required**: The LoRA name is propagated automatically through KV events (`BlockStored`) from the engine to the router. The router uses the `lora_name` field on events to route LoRA requests to workers that have matching cached blocks.
This works end-to-end across the publisher pipeline, the KV consolidator (for deduplication), and the routing query path.
## See Also
- [Feature Matrix](../../reference/feature-matrix.md) - Backend compatibility overview
- [vLLM Backend](../../backends/vllm/README.md) - vLLM-specific configuration
- [Dynamo Operator](../../kubernetes/dynamo-operator.md) - Kubernetes operator overview
- [KV-Aware Routing](../../components/router/router-guide.md) - LoRA-aware request routing
- [KV Events for Custom Engines](../../integrations/kv-events-custom-engines.md) - Publishing LoRA-aware KV events
......@@ -41,7 +41,7 @@ For `BlockStored` events:
- **`block_hashes`**: List of **sequence block hashes** from the engine's block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
- **`num_block_tokens`**: Number of tokens per block (should all equal `kv_block_size`)
- **`parent_hash`**: Hash of the parent block. Required for all blocks except the first block in a sequence (which has no parent).
- **`lora_id`**: LoRA adapter ID (0 if not using LoRA)
- **`lora_name`**: LoRA adapter name string (omit or `None` for base model). When set, the adapter name is incorporated into block hash computation so that blocks for different LoRA adapters (or the base model) are never conflated.
For `BlockRemoved` events:
- **`block_hashes`**: List of sequence block hashes being evicted
......@@ -93,15 +93,16 @@ class CustomEnginePublisher:
)
def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
lora_id: int = 0, parent_hash: int | None = None):
parent_hash: int | None = None,
lora_name: str | None = None):
"""Call after KV cache blocks are allocated."""
num_block_tokens = [self.block_size] * len(block_hashes)
self.kv_publisher.publish_stored(
token_ids=token_ids,
num_block_tokens=num_block_tokens,
block_hashes=block_hashes,
lora_id=lora_id,
parent_hash=parent_hash,
lora_name=lora_name,
)
def on_blocks_removed(self, block_hashes: list[int]):
......@@ -209,7 +210,7 @@ For `BlockStored`:
"parent_block_hash": signed_i64 | None, # Parent hash
"token_ids": [int, ...], # Token IDs
"block_size": int, # Tokens per block
"lora_id": int | None, # LoRA adapter ID
"lora_name": str | None, # LoRA adapter name
}
```
......@@ -257,12 +258,13 @@ publish_stored(
token_ids: list[int],
num_block_tokens: list[int],
block_hashes: list[int],
lora_id: int,
parent_hash: int | None = None,
block_mm_infos: list[dict | None] | None = None,
lora_name: str | None = None,
)
```
Publish a block-stored event. Event IDs are managed internally.
Publish a block-stored event. Event IDs are managed internally. When `lora_name` is provided, the adapter name is mixed into block hash computation so blocks cached under different adapters produce distinct hashes.
#### `publish_removed()`
......
......@@ -212,12 +212,13 @@ fn kv_event_create_stored_block_from_parts(
token_ids: *const u32,
num_tokens: usize,
kv_block_size: u32,
_lora_id: u64,
lora_name: Option<&str>,
) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq(
unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size,
None,
lora_name,
)[0];
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash),
......@@ -264,7 +265,7 @@ fn kv_event_create_stored_from_parts(
tokens,
num_toks,
kv_block_size,
kv_params.lora_id,
kv_params.lora_name.as_deref(),
));
}
......@@ -303,12 +304,13 @@ pub struct DynamoKvStoredEventParams {
pub block_ids: *const u64,
pub num_blocks: usize,
pub parent_hash: Option<u64>,
pub lora_id: u64,
pub lora_name: Option<String>,
}
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
/// has a parent hash or not. nullptr is used to represent no parent hash.
/// lora_name is an optional null-terminated C string; pass nullptr for base model.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
event_id: u64,
......@@ -317,7 +319,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
block_ids: *const u64,
num_blocks: usize,
parent_hash: *const u64,
lora_id: u64,
lora_name: *const c_char,
) -> DynamoLlmResult {
let parent_hash = {
if parent_hash.is_null() {
......@@ -326,6 +328,17 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
Some(unsafe { *parent_hash })
}
};
let lora_name = if lora_name.is_null() {
None
} else {
match unsafe { CStr::from_ptr(lora_name) }.to_str() {
Ok(s) => Some(s.to_owned()),
Err(e) => {
tracing::error!(error = ?e, "Failed to convert C string to Rust string (lora_name)");
return DynamoLlmResult::ERR;
}
}
};
let kv_params = DynamoKvStoredEventParams {
event_id,
token_ids,
......@@ -333,7 +346,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
block_ids,
num_blocks,
parent_hash,
lora_id,
lora_name,
};
let publisher = KV_PUB.get().unwrap();
let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
......@@ -789,7 +802,10 @@ pub unsafe extern "C" fn add_request(
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Compute overlap_blocks using the public method
let overlap_blocks = match decode_router.get_overlap_blocks(&tokens, worker).await {
let overlap_blocks = match decode_router
.get_overlap_blocks(&tokens, worker, None)
.await
{
Ok(overlap) => overlap,
Err(e) => {
tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
......
......@@ -47,12 +47,13 @@ pub fn start_kv_block_indexer_py<'p>(
}
#[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
pub fn compute_block_hash_for_seq_py(
_py: Python,
tokens: Vec<u32>,
kv_block_size: usize,
block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>,
) -> PyResult<Vec<u64>> {
if kv_block_size == 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
......@@ -65,7 +66,12 @@ pub fn compute_block_hash_for_seq_py(
.map(depythonize_block_mm_infos)
.transpose()?;
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos.as_deref());
let hashes = compute_block_hash_for_seq(
&tokens,
kv_block_size as u32,
mm_infos.as_deref(),
lora_name.as_deref(),
);
Ok(hashes.into_iter().map(|h| h.0).collect())
}
......@@ -169,23 +175,22 @@ impl KvEventPublisher {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
fn publish_stored(
&self,
py: Python,
token_ids: Vec<u32>,
num_block_tokens: Vec<u64>,
block_hashes: Vec<i64>,
lora_id: u64,
parent_hash: Option<i64>,
block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>,
) -> PyResult<()> {
let kv_block_size = self.kv_block_size as u32;
let dp_rank = self.dp_rank;
let warning_count = self.warning_count.clone();
let inner = self.inner.clone();
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id();
let mm_infos = block_mm_infos
......@@ -204,7 +209,7 @@ impl KvEventPublisher {
&token_ids,
&num_block_tokens,
&block_hashes_u64,
lora_id,
lora_name.as_deref(),
&warning_count,
mm_infos.as_deref(),
),
......@@ -879,7 +884,7 @@ impl KvRouter {
Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
}
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None))]
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
fn best_worker<'p>(
&self,
py: Python<'p>,
......@@ -887,6 +892,7 @@ impl KvRouter {
router_config_override: Option<PyObject>,
request_id: Option<String>,
block_mm_infos: Option<PyObject>,
lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
let override_config: llm_rs::kv_router::RouterConfigOverride =
......@@ -911,7 +917,7 @@ impl KvRouter {
block_mm_infos.as_deref(),
router_config_override.as_ref(),
update_states,
None, // lora_name not exposed in Python API yet
lora_name,
0.0,
)
.await
......@@ -948,16 +954,18 @@ impl KvRouter {
})
}
#[pyo3(signature = (token_ids, lora_name=None))]
fn get_potential_loads<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
let chooser = self.inner.chooser.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = chooser
.get_potential_loads(&token_ids, None)
.get_potential_loads(&token_ids, None, lora_name.as_deref())
.await
.map_err(to_pyerr)?;
......
......@@ -236,7 +236,8 @@ class ModelCardInstanceId:
def compute_block_hash_for_seq(
tokens: List[int],
kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> List[int]:
"""
Compute block hashes for a sequence of tokens, optionally including multimodal metadata.
......@@ -562,7 +563,7 @@ class KvIndexer:
...
def find_matches_for_request(
self, token_ids: List[int], lora_id: int
self, token_ids: List[int], lora_name: Optional[str] = None
) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
......@@ -610,13 +611,14 @@ class ApproxKvIndexer:
...
def find_matches_for_request(
self, token_ids: List[int]
self, token_ids: List[int], lora_name: Optional[str] = None
) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
Args:
token_ids: List of token IDs to find matches for
lora_name: Optional LoRA adapter name for adapter-aware matching
Returns:
OverlapScores containing worker matching scores and frequencies
......@@ -689,9 +691,9 @@ class KvEventPublisher:
token_ids: List[int],
num_block_tokens: List[int],
block_hashes: List[int],
lora_id: int,
parent_hash: Optional[int] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> None:
"""
Publish a KV stored event.
......@@ -702,11 +704,11 @@ class KvEventPublisher:
token_ids: List of token IDs
num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer)
block_mm_infos: Optional list of multimodal info for each block.
Each item is either None or a dict with "mm_objects" key containing
a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts.
lora_name: Optional LoRA adapter name for adapter-aware block hashing.
"""
...
......@@ -1388,6 +1390,7 @@ class KvRouter:
router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> Tuple[int, int, int]:
"""
Find the best matching worker for the given tokens.
......@@ -1413,6 +1416,7 @@ class KvRouter:
async def get_potential_loads(
self,
token_ids: List[int],
lora_name: Optional[str] = None,
) -> List[Dict[str, int]]:
"""
Get potential prefill and decode loads for all workers.
......
......@@ -10,13 +10,14 @@ Every cached KV block in a distributed LLM system needs four pieces of informati
### 1. Local Block Hash (`LocalBlockHash`, u64)
**What**: Hash of the tokens *within* a single block (e.g., 64 tokens).
**What**: Hash of the tokens *within* a single block (e.g., 64 tokens), optionally including LoRA adapter name and multimodal metadata.
**Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens have the same local hash.
**Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens (and same LoRA adapter) have the same local hash. When a LoRA adapter name is provided, it is length-prefixed and appended to the byte buffer before hashing, ensuring that blocks under different adapters (or the base model) always produce distinct hashes.
```
```text
Block at position 5: tokens [101, 102, 103, ...]
LocalBlockHash = hash(tokens) = 0xABCD1234
LocalBlockHash = hash(tokens) = 0xABCD1234 (base model)
LocalBlockHash = hash(tokens || len("my-lora") || "my-lora") = 0xDEAD5678 (LoRA adapter)
```
### 2. External Sequence Block Hash (`ExternalSequenceBlockHash`, u64)
......@@ -39,6 +40,8 @@ block2 in B: seq_hash = hash(hash(hash(block0') || block1') || block2) = 0x2222
>
> In practice, the `ExternalSequenceBlockHash` may come directly from the inference engine (e.g., vLLM, TensorRT-LLM) using a rolling hash algorithm that we don't know or control. The engine computes these hashes internally and reports them via KV cache events.
>
> **LoRA identity**: The engine is responsible for incorporating the LoRA adapter identity into the `ExternalSequenceBlockHash` before emitting KV events. Dynamo does not add LoRA information at the router layer. For example, vLLM does this via `_gen_lora_extra_hash_keys`, which appends the LoRA ID as extra keys when calling `hash_block_tokens(..., extra_keys)`. Any engine integrating with the KV router must follow the same convention to ensure correct cache isolation between LoRA adapters.
>
> **Implications for index implementations:**
>
> - **RadixTree**: Can handle engine-provided hashes because it traverses the tree structure using `LocalBlockHash` for navigation and only uses `ExternalSequenceBlockHash` as an opaque identifier for lookups. It doesn't need to recompute hashes.
......
......@@ -302,7 +302,7 @@ fn bench_hash(args: &Args) {
for (i, tokens) in token_sequences.iter().enumerate() {
let start = Instant::now();
let _ = compute_block_hash_for_seq(tokens, args.block_size, None);
let _ = compute_block_hash_for_seq(tokens, args.block_size, None, None);
let elapsed = start.elapsed();
if i >= warmup_iters {
......
......@@ -349,7 +349,7 @@ mod tests {
// 1. Before routing decision there should be no matches
let pre_scores = indexer
.find_matches_for_request(&tokens)
.find_matches_for_request(&tokens, None)
.await
.expect("indexer offline");
assert!(pre_scores.scores.is_empty());
......@@ -366,7 +366,10 @@ mod tests {
// Poll until we observe the match being registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_id))
.copied()
......@@ -376,7 +379,10 @@ mod tests {
// 3. After the TTL has passed the entry should expire automatically
time::sleep(TTL + Duration::from_millis(50)).await;
let post_scores = indexer.find_matches_for_request(&tokens).await.unwrap();
let post_scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert!(post_scores.scores.is_empty());
}
......@@ -413,7 +419,10 @@ mod tests {
// Wait until the worker is registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
})
......@@ -424,7 +433,10 @@ mod tests {
// Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
!s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
})
......@@ -475,7 +487,10 @@ mod tests {
// Ensure both workers are registered
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
......@@ -492,7 +507,10 @@ mod tests {
// Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
!s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
&& s.scores
......@@ -539,7 +557,10 @@ mod tests {
// Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
let s = indexer
.find_matches_for_request(&seq_a, None)
.await
.unwrap();
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a))
.copied()
......@@ -551,7 +572,10 @@ mod tests {
let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
// Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();
let overlap = indexer
.find_matches_for_request(&seq_b, None)
.await
.unwrap();
// Expect worker A to have an overlap score of 1 (shared first block)
assert_eq!(
......@@ -606,7 +630,10 @@ mod tests {
// Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap();
let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied()
......@@ -618,7 +645,10 @@ mod tests {
})
.await;
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert_eq!(
scores
......@@ -777,7 +807,10 @@ mod tests {
// Verify all 5 blocks are present (no pruning yet)
for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert_eq!(
scores.scores.get(&worker).copied(),
Some(1),
......@@ -803,7 +836,10 @@ mod tests {
// Verify that the 4 oldest blocks are pruned
for i in 0..4 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert!(
scores.scores.get(&worker).copied().unwrap_or(0) == 0,
"Block {} should have been pruned but is still present",
......@@ -814,7 +850,10 @@ mod tests {
// Verify the 2 newest blocks are present
for i in 4..6 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert_eq!(
scores.scores.get(&worker).copied(),
Some(1),
......
......@@ -1122,7 +1122,9 @@ mod tests {
tokio::time::sleep(Duration::from_millis(100)).await;
let scores = indexer.find_matches_for_request(&[100, 200, 300]).await;
let scores = indexer
.find_matches_for_request(&[100, 200, 300], None)
.await;
assert!(scores.is_ok());
indexer.shutdown();
......
......@@ -302,6 +302,7 @@ pub trait KvIndexerInterface {
/// ### Arguments
///
/// * `tokens` - A vector of `u32` tokens.
/// * `lora_name` - Optional LoRA adapter name to include in block hash computation.
///
/// ### Returns
///
......@@ -309,6 +310,7 @@ pub trait KvIndexerInterface {
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError>;
/// Apply a `RouterEvent` to the KV store.
......@@ -510,8 +512,9 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
Ok(self.backend.find_matches(&sequence, false))
}
......@@ -972,13 +975,14 @@ impl KvIndexerInterface for KvIndexer {
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> {
tracing::debug!(
"Finding matches for request tokens: {:?} / len: {}",
tokens,
tokens.len()
);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
tracing::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await
}
......@@ -1307,8 +1311,11 @@ impl KvIndexerInterface for LocalKvIndexer {
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches_for_request(tokens).await
self.indexer
.find_matches_for_request(tokens, lora_name)
.await
}
async fn apply_event(&self, event: RouterEvent) {
......@@ -1760,8 +1767,9 @@ impl KvIndexerInterface for KvIndexerSharded {
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
self.find_matches(sequence).await
}
......@@ -2350,7 +2358,7 @@ mod tests {
// Empty index should return no matches
let tokens = vec![1, 2, 3, 4];
let scores = index.find_matches_for_request(&tokens).await.unwrap();
let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
assert!(scores.scores.is_empty());
// Store some data and verify we can find it via tokens
......@@ -2362,7 +2370,7 @@ mod tests {
// Note: find_matches_for_request computes block hashes from tokens,
// so we need tokens that hash to the same LocalBlockHash values.
// For this test, we just verify the method works without error.
let scores = index.find_matches_for_request(&tokens).await.unwrap();
let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens
assert!(scores.scores.is_empty() || !scores.scores.is_empty());
......
......@@ -174,6 +174,7 @@ impl KvIndexerInterface for NaiveNestedMap {
async fn find_matches_for_request(
&self,
_tokens: &[u32],
_lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench")
}
......@@ -384,6 +385,7 @@ impl KvIndexerInterface for InvertedIndex {
async fn find_matches_for_request(
&self,
_tokens: &[u32],
_lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench")
}
......
......@@ -19,23 +19,34 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data))
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// and LoRA adapter identity.
///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes.
///
/// When `lora_name` is provided, the adapter name is mixed into the XXH3 seed so that
/// blocks cached under different LoRA adapters (or the base model) produce distinct hashes.
/// Because LoRA identity applies uniformly to every block in a sequence, encoding it in the
/// seed is more efficient than appending per-block bytes and matches the approach used by
/// KVBM's `SaltHash`.
pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
lora_name: Option<&str>,
) -> Vec<LocalBlockHash> {
let seed = match lora_name.filter(|n| !n.is_empty()) {
Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())),
None => XXH3_SEED,
};
tokens
.chunks_exact(kv_block_size as usize)
.enumerate()
.map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
......@@ -51,7 +62,7 @@ pub fn compute_block_hash_for_seq(
}
}
compute_block_hash(&bytes)
LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed))
})
.collect()
}
......@@ -176,13 +187,13 @@ pub struct ActiveLoad {
pub active_prefill_tokens: Option<u64>,
}
/// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
/// lora_id of a block.
/// A [`LocalBlockHash`] is a hash computed from the token IDs, optional multimodal metadata,
/// and optional LoRA adapter name of a block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct LocalBlockHash(pub u64);
/// A sequence aware hash of a block where the hash is computed from the tokens_ids, extra_token_ids
/// and the optional lora_id of a block, PLUS the hash of the parent block.
/// A sequence-aware hash of a block computed by the engine from token IDs, optional metadata,
/// and the hash of the parent block.
///
/// In this case, the hashing function is external and unknown.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
......@@ -575,6 +586,7 @@ pub struct TokensWithHashes {
tokens: Vec<u32>,
block_size: u32,
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
lora_name: Option<String>,
block_hashes: Option<Vec<LocalBlockHash>>,
seq_hashes: Option<Vec<SequenceHash>>,
}
......@@ -586,6 +598,7 @@ impl TokensWithHashes {
tokens,
block_size,
block_mm_infos: None,
lora_name: None,
block_hashes: None,
seq_hashes: None,
}
......@@ -597,6 +610,12 @@ impl TokensWithHashes {
self
}
/// Sets the LoRA adapter name for hash computation.
pub fn with_lora_name(mut self, name: String) -> Self {
self.lora_name = Some(name);
self
}
/// Returns a reference to the tokens.
pub fn tokens(&self) -> &[u32] {
&self.tokens
......@@ -629,6 +648,7 @@ impl TokensWithHashes {
&self.tokens,
self.block_size,
self.block_mm_infos.as_deref(),
self.lora_name.as_deref(),
));
}
self.block_hashes.as_ref().unwrap()
......@@ -708,22 +728,67 @@ mod tests {
#[case(32)]
#[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
// create a sequence of kv_block_size elements
let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
assert_eq!(hashes.len(), 1);
// create a sequence of kv_block_size + 1 elements
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
assert_eq!(hashes.len(), 1);
// create a sequence of 2 * kv_block_size + 1 elements
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
assert_eq!(hashes.len(), 2);
}
#[test]
fn test_lora_name_produces_different_hash() {
let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None);
let lora_a = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-a"));
let lora_b = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-b"));
assert_ne!(base[0], lora_a[0]);
assert_ne!(base[0], lora_b[0]);
assert_ne!(lora_a[0], lora_b[0]);
}
#[test]
fn test_lora_name_none_matches_legacy() {
let tokens: Vec<u32> = (0..8).collect();
let hashes_none = compute_block_hash_for_seq(&tokens, 4, None, None);
let hashes_none2 = compute_block_hash_for_seq(&tokens, 4, None, None);
assert_eq!(hashes_none, hashes_none2);
}
#[test]
fn test_lora_name_empty_string_normalized_to_none() {
let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None);
let empty = compute_block_hash_for_seq(&tokens, 4, None, Some(""));
assert_eq!(
base, empty,
"empty lora_name should be treated as base model"
);
}
#[test]
fn test_tokens_with_hashes_lora() {
let tokens: Vec<u32> = (0..8).collect();
let mut base = TokensWithHashes::new(tokens.clone(), 4);
let base_hashes = base.get_or_compute_block_hashes().to_vec();
let mut with_lora =
TokensWithHashes::new(tokens, 4).with_lora_name("my-adapter".to_string());
let lora_hashes = with_lora.get_or_compute_block_hashes().to_vec();
assert_eq!(base_hashes.len(), lora_hashes.len());
for (b, l) in base_hashes.iter().zip(lora_hashes.iter()) {
assert_ne!(b, l);
}
}
#[test]
fn test_local_block_hash_serialization() {
let hash = LocalBlockHash(12345);
......
......@@ -223,7 +223,7 @@ impl DynamoEventManager {
tokens,
parent_hash,
block_size,
None, // lora_id
None, // lora_name
None, // tier
None, // data_parallel_rank
)
......
......@@ -40,7 +40,7 @@ impl KvEventConsolidatorHandle {
token_ids: Vec<u32>,
parent_hash: Option<String>,
block_size: usize,
lora_id: Option<u64>,
lora_name: Option<String>,
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) {
......@@ -51,7 +51,7 @@ impl KvEventConsolidatorHandle {
token_ids,
parent_hash,
block_size,
lora_id.map(|id| id as i32),
lora_name,
tier,
data_parallel_rank,
);
......
......@@ -31,7 +31,6 @@ struct EventBatch(
);
/// Event types matching vLLM's format
/// Note: Uses i32 for token_ids, block_size, and lora_id to match vLLM's ZmqEventPublisher
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum Event {
......@@ -41,7 +40,8 @@ enum Event {
parent_block_hash: Option<u64>,
token_ids: Vec<i32>,
block_size: i32,
lora_id: Option<i32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
lora_name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
medium: Option<String>,
},
......@@ -68,15 +68,13 @@ impl Event {
parent_hash,
token_ids,
block_size,
lora_id,
source: _, // Source used for logging only, not sent to router
lora_name,
source: _,
} => {
// Parse block hash - fail if invalid to prevent corruption
let parsed_hash = block_hash
.parse::<u64>()
.with_context(|| format!("Failed to parse block_hash: {}", block_hash))?;
// Parse parent hash if present - fail if invalid
let parsed_parent = parent_hash
.map(|h| {
h.parse::<u64>()
......@@ -84,8 +82,6 @@ impl Event {
})
.transpose()?;
// Convert u32 token_ids to i32 for vLLM compatibility
// Token IDs should never exceed i32::MAX in practice, but we handle it gracefully
let token_ids_i32: Vec<i32> = token_ids
.into_iter()
.map(|t| {
......@@ -96,7 +92,6 @@ impl Event {
})
.collect();
// Convert usize block_size to i32 for vLLM compatibility
let block_size_i32 = i32::try_from(block_size).unwrap_or_else(|_| {
tracing::warn!(
"Block size {} exceeds i32::MAX, clamping to i32::MAX",
......@@ -105,16 +100,13 @@ impl Event {
i32::MAX
});
// lora_id is already Option<i32> in ConsolidatedEvent::Store
let lora_id_i32 = lora_id;
Ok(Event::BlockStored {
block_hashes: vec![parsed_hash],
parent_block_hash: parsed_parent,
token_ids: token_ids_i32,
block_size: block_size_i32,
lora_id: lora_id_i32,
medium: None, // Not provided by ConsolidatedEvent
lora_name,
medium: None,
})
}
ConsolidatedEvent::Remove {
......
......@@ -145,12 +145,9 @@ enum VllmRawEvent {
parent_block_hash: Option<BlockHash>,
token_ids: Vec<i32>,
block_size: i32,
lora_id: Option<i32>,
#[serde(default)]
medium: Option<String>,
#[serde(default)]
#[allow(dead_code)]
// Reserved for future use, needed for vLLM 0.14.0 deserialization
lora_name: Option<String>,
},
#[serde(rename = "BlockRemoved")]
......@@ -279,9 +276,8 @@ fn process_event(
parent_block_hash,
token_ids,
block_size,
lora_id,
medium,
lora_name: _, // Not used yet, lora_id is still used for backwards compat
lora_name,
} => {
let storage_tier = medium
.as_ref()
......@@ -370,7 +366,7 @@ fn process_event(
block_tokens,
current_parent.clone(),
block_size_usize,
lora_id,
lora_name.clone(),
Some(storage_tier),
data_parallel_rank,
);
......
......@@ -188,8 +188,8 @@ pub enum ConsolidatedEvent {
parent_hash: Option<String>,
token_ids: Vec<u32>,
block_size: usize,
lora_id: Option<i32>,
source: String, // The source where it was first stored (vllm or kvbm)
lora_name: Option<String>,
source: String,
},
/// Block removed (removed from all sources)
Remove {
......@@ -257,7 +257,7 @@ impl CacheStatusTracker {
token_ids: Vec<u32>,
parent_hash: Option<String>,
block_size: usize,
lora_id: Option<i32>,
lora_name: Option<String>,
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) -> bool {
......@@ -327,7 +327,7 @@ impl CacheStatusTracker {
.as_ref()
.map(|p| &p[..16.min(p.len())])
.unwrap_or("none"),
lora_id,
lora_name,
data_parallel_rank,
&token_ids
);
......@@ -366,7 +366,7 @@ impl CacheStatusTracker {
parent_hash: resolved_parent_hash,
token_ids,
block_size,
lora_id,
lora_name,
source: source.to_str().to_string(),
});
......
......@@ -370,8 +370,14 @@ impl KvRouter {
let isl_tokens = tokens.len();
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, block_mm_infos));
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes").in_scope(|| {
compute_block_hash_for_seq(
tokens,
self.block_size,
block_mm_infos,
lora_name.as_deref(),
)
});
let hash_elapsed = start.elapsed();
let overlap_scores = self
......@@ -381,12 +387,12 @@ impl KvRouter {
.await?;
let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
lora_name.as_deref(),
)
});
let seq_hash_elapsed = start.elapsed();
......@@ -453,6 +459,7 @@ impl KvRouter {
tokens,
self.block_size,
router_config_override,
lora_name.as_deref(),
);
if let Err(e) = self
......@@ -506,8 +513,9 @@ impl KvRouter {
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<u32, KvRouterError> {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
}
......@@ -517,15 +525,17 @@ impl KvRouter {
&self,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>,
) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
self.block_size,
router_config_override,
lora_name,
);
Ok(self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment