Unverified Commit 75fea787 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: make KV cache events and routing LoRA-aware (#6517)

parent 432fae67
...@@ -116,10 +116,10 @@ class ZmqKvEventPublisher: ...@@ -116,10 +116,10 @@ class ZmqKvEventPublisher:
token_ids: list[int], token_ids: list[int],
num_block_tokens: list[int], num_block_tokens: list[int],
block_hashes: list[int], block_hashes: list[int],
lora_id: int = 0,
parent_hash: Optional[int] = None, parent_hash: Optional[int] = None,
block_mm_infos: Optional[list[dict | None]] = None, block_mm_infos: Optional[list[dict | None]] = None,
attention_dp_rank: int = 0, attention_dp_rank: int = 0,
lora_name: Optional[str] = None,
): ):
"""Publish a BlockStored event. """Publish a BlockStored event.
...@@ -139,8 +139,9 @@ class ZmqKvEventPublisher: ...@@ -139,8 +139,9 @@ class ZmqKvEventPublisher:
"parent_block_hash": parent_hash_signed, "parent_block_hash": parent_hash_signed,
"token_ids": token_ids, "token_ids": token_ids,
"block_size": self.kv_block_size, "block_size": self.kv_block_size,
"lora_id": lora_id if lora_id != 0 else None,
} }
if lora_name is not None:
event["lora_name"] = lora_name
# Add multimodal info if present # Add multimodal info if present
if block_mm_infos is not None: if block_mm_infos is not None:
...@@ -597,17 +598,14 @@ class Publisher: ...@@ -597,17 +598,14 @@ class Publisher:
else: else:
block_mm_infos.append(None) block_mm_infos.append(None)
# Note: Currently data does not have lora_id. lora_name = data.get("lora_name")
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id = data.get("lora_id", 0)
# Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent) # Get attention_dp_rank from event (TRT-LLM includes this in KVCacheEvent)
# Default to 0 for backwards compatibility with older TRT-LLM versions # Default to 0 for backwards compatibility with older TRT-LLM versions
attention_dp_rank = event.get("attention_dp_rank", 0) attention_dp_rank = event.get("attention_dp_rank", 0)
logging.debug( logging.debug(
f"publish stored event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}" f"publish stored event: engine_event_id: {event_id}, attention_dp_rank: {attention_dp_rank}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_name: {lora_name}, parent_hash: {parent_hash}"
) )
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS # Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank) # Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
...@@ -617,10 +615,10 @@ class Publisher: ...@@ -617,10 +615,10 @@ class Publisher:
token_ids, token_ids,
num_block_tokens, num_block_tokens,
block_hashes, block_hashes,
lora_id,
parent_hash, parent_hash,
block_mm_infos, block_mm_infos,
attention_dp_rank, attention_dp_rank,
lora_name,
) )
elif self.kv_event_publishers: elif self.kv_event_publishers:
# No consolidator: publish to NATS (router subscribes directly) # No consolidator: publish to NATS (router subscribes directly)
...@@ -631,9 +629,9 @@ class Publisher: ...@@ -631,9 +629,9 @@ class Publisher:
token_ids, token_ids,
num_block_tokens, num_block_tokens,
block_hashes, block_hashes,
lora_id,
parent_hash, parent_hash,
block_mm_infos, block_mm_infos,
lora_name=lora_name,
) )
else: else:
logging.warning( logging.warning(
......
...@@ -88,7 +88,7 @@ To get a feel for how KV Cache management works on a single worker with KV Cache ...@@ -88,7 +88,7 @@ To get a feel for how KV Cache management works on a single worker with KV Cache
1. **Request tokenization**: The incoming prompt is converted into tokens 1. **Request tokenization**: The incoming prompt is converted into tokens
2. **Block partitioning**: The token sequence is divided into fixed-size blocks (e.g., 16 or 64 tokens per block) 2. **Block partitioning**: The token sequence is divided into fixed-size blocks (e.g., 16 or 64 tokens per block)
3. **Block hashing**: Each block of tokens is hashed to create a unique identifier 3. **Block hashing**: Each block of tokens is hashed to create a unique identifier. When a LoRA adapter is active, the adapter name is incorporated into the hash so that blocks cached under different adapters produce distinct identifiers.
4. **Cache lookup**: 4. **Cache lookup**:
- For each block, the system checks if a matching block already exists in the KV cache - For each block, the system checks if a matching block already exists in the KV cache
- If a match is found, the existing KV cache block is reused - If a match is found, the existing KV cache block is reused
......
...@@ -306,9 +306,20 @@ kubectl logs deployment/my-worker | grep -i lora ...@@ -306,9 +306,20 @@ kubectl logs deployment/my-worker | grep -i lora
- Check that the LoRA is loaded on the worker handling your request - Check that the LoRA is loaded on the worker handling your request
- For disaggregated serving, ensure both prefill and decode workers have the LoRA - For disaggregated serving, ensure both prefill and decode workers have the LoRA
## KV Cache-Aware LoRA Routing
When KV-aware routing is enabled, the router automatically accounts for LoRA adapter identity when computing block hashes. This means:
- **Distinct hash spaces per adapter**: Blocks cached under adapter `A` will never be confused with blocks cached under adapter `B` or the base model, even if the token sequences are identical. The adapter name is mixed into the `LocalBlockHash` computation.
- **Automatic prefix sharing within the same adapter**: Requests targeting the same LoRA adapter benefit from KV cache prefix matching just like base model requests do.
- **No configuration required**: The LoRA name is propagated automatically through KV events (`BlockStored`) from the engine to the router. The router uses the `lora_name` field on events to route LoRA requests to workers that have matching cached blocks.
This works end-to-end across the publisher pipeline, the KV consolidator (for deduplication), and the routing query path.
## See Also ## See Also
- [Feature Matrix](../../reference/feature-matrix.md) - Backend compatibility overview - [Feature Matrix](../../reference/feature-matrix.md) - Backend compatibility overview
- [vLLM Backend](../../backends/vllm/README.md) - vLLM-specific configuration - [vLLM Backend](../../backends/vllm/README.md) - vLLM-specific configuration
- [Dynamo Operator](../../kubernetes/dynamo-operator.md) - Kubernetes operator overview - [Dynamo Operator](../../kubernetes/dynamo-operator.md) - Kubernetes operator overview
- [KV-Aware Routing](../../components/router/router-guide.md) - LoRA-aware request routing - [KV-Aware Routing](../../components/router/router-guide.md) - LoRA-aware request routing
- [KV Events for Custom Engines](../../integrations/kv-events-custom-engines.md) - Publishing LoRA-aware KV events
...@@ -41,7 +41,7 @@ For `BlockStored` events: ...@@ -41,7 +41,7 @@ For `BlockStored` events:
- **`block_hashes`**: List of **sequence block hashes** from the engine's block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests. - **`block_hashes`**: List of **sequence block hashes** from the engine's block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
- **`num_block_tokens`**: Number of tokens per block (should all equal `kv_block_size`) - **`num_block_tokens`**: Number of tokens per block (should all equal `kv_block_size`)
- **`parent_hash`**: Hash of the parent block. Required for all blocks except the first block in a sequence (which has no parent). - **`parent_hash`**: Hash of the parent block. Required for all blocks except the first block in a sequence (which has no parent).
- **`lora_id`**: LoRA adapter ID (0 if not using LoRA) - **`lora_name`**: LoRA adapter name string (omit or `None` for base model). When set, the adapter name is incorporated into block hash computation so that blocks for different LoRA adapters (or the base model) are never conflated.
For `BlockRemoved` events: For `BlockRemoved` events:
- **`block_hashes`**: List of sequence block hashes being evicted - **`block_hashes`**: List of sequence block hashes being evicted
...@@ -93,15 +93,16 @@ class CustomEnginePublisher: ...@@ -93,15 +93,16 @@ class CustomEnginePublisher:
) )
def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int], def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
lora_id: int = 0, parent_hash: int | None = None): parent_hash: int | None = None,
lora_name: str | None = None):
"""Call after KV cache blocks are allocated.""" """Call after KV cache blocks are allocated."""
num_block_tokens = [self.block_size] * len(block_hashes) num_block_tokens = [self.block_size] * len(block_hashes)
self.kv_publisher.publish_stored( self.kv_publisher.publish_stored(
token_ids=token_ids, token_ids=token_ids,
num_block_tokens=num_block_tokens, num_block_tokens=num_block_tokens,
block_hashes=block_hashes, block_hashes=block_hashes,
lora_id=lora_id,
parent_hash=parent_hash, parent_hash=parent_hash,
lora_name=lora_name,
) )
def on_blocks_removed(self, block_hashes: list[int]): def on_blocks_removed(self, block_hashes: list[int]):
...@@ -209,7 +210,7 @@ For `BlockStored`: ...@@ -209,7 +210,7 @@ For `BlockStored`:
"parent_block_hash": signed_i64 | None, # Parent hash "parent_block_hash": signed_i64 | None, # Parent hash
"token_ids": [int, ...], # Token IDs "token_ids": [int, ...], # Token IDs
"block_size": int, # Tokens per block "block_size": int, # Tokens per block
"lora_id": int | None, # LoRA adapter ID "lora_name": str | None, # LoRA adapter name
} }
``` ```
...@@ -257,12 +258,13 @@ publish_stored( ...@@ -257,12 +258,13 @@ publish_stored(
token_ids: list[int], token_ids: list[int],
num_block_tokens: list[int], num_block_tokens: list[int],
block_hashes: list[int], block_hashes: list[int],
lora_id: int,
parent_hash: int | None = None, parent_hash: int | None = None,
block_mm_infos: list[dict | None] | None = None,
lora_name: str | None = None,
) )
``` ```
Publish a block-stored event. Event IDs are managed internally. Publish a block-stored event. Event IDs are managed internally. When `lora_name` is provided, the adapter name is mixed into block hash computation so blocks cached under different adapters produce distinct hashes.
#### `publish_removed()` #### `publish_removed()`
......
...@@ -212,12 +212,13 @@ fn kv_event_create_stored_block_from_parts( ...@@ -212,12 +212,13 @@ fn kv_event_create_stored_block_from_parts(
token_ids: *const u32, token_ids: *const u32,
num_tokens: usize, num_tokens: usize,
kv_block_size: u32, kv_block_size: u32,
_lora_id: u64, lora_name: Option<&str>,
) -> KvCacheStoredBlockData { ) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq( let tokens_hash = compute_block_hash_for_seq(
unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }, unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size, kv_block_size,
None, None,
lora_name,
)[0]; )[0];
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash), block_hash: ExternalSequenceBlockHash(block_hash),
...@@ -264,7 +265,7 @@ fn kv_event_create_stored_from_parts( ...@@ -264,7 +265,7 @@ fn kv_event_create_stored_from_parts(
tokens, tokens,
num_toks, num_toks,
kv_block_size, kv_block_size,
kv_params.lora_id, kv_params.lora_name.as_deref(),
)); ));
} }
...@@ -303,12 +304,13 @@ pub struct DynamoKvStoredEventParams { ...@@ -303,12 +304,13 @@ pub struct DynamoKvStoredEventParams {
pub block_ids: *const u64, pub block_ids: *const u64,
pub num_blocks: usize, pub num_blocks: usize,
pub parent_hash: Option<u64>, pub parent_hash: Option<u64>,
pub lora_id: u64, pub lora_name: Option<String>,
} }
/// # Safety /// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks /// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash /// has a parent hash or not. nullptr is used to represent no parent hash.
/// lora_name is an optional null-terminated C string; pass nullptr for base model.
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_kv_event_publish_stored( pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
event_id: u64, event_id: u64,
...@@ -317,7 +319,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( ...@@ -317,7 +319,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
block_ids: *const u64, block_ids: *const u64,
num_blocks: usize, num_blocks: usize,
parent_hash: *const u64, parent_hash: *const u64,
lora_id: u64, lora_name: *const c_char,
) -> DynamoLlmResult { ) -> DynamoLlmResult {
let parent_hash = { let parent_hash = {
if parent_hash.is_null() { if parent_hash.is_null() {
...@@ -326,6 +328,17 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( ...@@ -326,6 +328,17 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
Some(unsafe { *parent_hash }) Some(unsafe { *parent_hash })
} }
}; };
let lora_name = if lora_name.is_null() {
None
} else {
match unsafe { CStr::from_ptr(lora_name) }.to_str() {
Ok(s) => Some(s.to_owned()),
Err(e) => {
tracing::error!(error = ?e, "Failed to convert C string to Rust string (lora_name)");
return DynamoLlmResult::ERR;
}
}
};
let kv_params = DynamoKvStoredEventParams { let kv_params = DynamoKvStoredEventParams {
event_id, event_id,
token_ids, token_ids,
...@@ -333,7 +346,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( ...@@ -333,7 +346,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
block_ids, block_ids,
num_blocks, num_blocks,
parent_hash, parent_hash,
lora_id, lora_name,
}; };
let publisher = KV_PUB.get().unwrap(); let publisher = KV_PUB.get().unwrap();
let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size()); let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
...@@ -789,7 +802,10 @@ pub unsafe extern "C" fn add_request( ...@@ -789,7 +802,10 @@ pub unsafe extern "C" fn add_request(
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Compute overlap_blocks using the public method // Compute overlap_blocks using the public method
let overlap_blocks = match decode_router.get_overlap_blocks(&tokens, worker).await { let overlap_blocks = match decode_router
.get_overlap_blocks(&tokens, worker, None)
.await
{
Ok(overlap) => overlap, Ok(overlap) => overlap,
Err(e) => { Err(e) => {
tracing::warn!(error = ?e, "Failed to compute overlap, using 0"); tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
......
...@@ -47,12 +47,13 @@ pub fn start_kv_block_indexer_py<'p>( ...@@ -47,12 +47,13 @@ pub fn start_kv_block_indexer_py<'p>(
} }
#[pyfunction] #[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))] #[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None, lora_name=None))]
pub fn compute_block_hash_for_seq_py( pub fn compute_block_hash_for_seq_py(
_py: Python, _py: Python,
tokens: Vec<u32>, tokens: Vec<u32>,
kv_block_size: usize, kv_block_size: usize,
block_mm_infos: Option<Bound<PyAny>>, block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>,
) -> PyResult<Vec<u64>> { ) -> PyResult<Vec<u64>> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
...@@ -65,7 +66,12 @@ pub fn compute_block_hash_for_seq_py( ...@@ -65,7 +66,12 @@ pub fn compute_block_hash_for_seq_py(
.map(depythonize_block_mm_infos) .map(depythonize_block_mm_infos)
.transpose()?; .transpose()?;
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos.as_deref()); let hashes = compute_block_hash_for_seq(
&tokens,
kv_block_size as u32,
mm_infos.as_deref(),
lora_name.as_deref(),
);
Ok(hashes.into_iter().map(|h| h.0).collect()) Ok(hashes.into_iter().map(|h| h.0).collect())
} }
...@@ -169,23 +175,22 @@ impl KvEventPublisher { ...@@ -169,23 +175,22 @@ impl KvEventPublisher {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))] #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, parent_hash=None, block_mm_infos=None, lora_name=None))]
fn publish_stored( fn publish_stored(
&self, &self,
py: Python, py: Python,
token_ids: Vec<u32>, token_ids: Vec<u32>,
num_block_tokens: Vec<u64>, num_block_tokens: Vec<u64>,
block_hashes: Vec<i64>, block_hashes: Vec<i64>,
lora_id: u64,
parent_hash: Option<i64>, parent_hash: Option<i64>,
block_mm_infos: Option<Bound<PyAny>>, block_mm_infos: Option<Bound<PyAny>>,
lora_name: Option<String>,
) -> PyResult<()> { ) -> PyResult<()> {
let kv_block_size = self.kv_block_size as u32; let kv_block_size = self.kv_block_size as u32;
let dp_rank = self.dp_rank; let dp_rank = self.dp_rank;
let warning_count = self.warning_count.clone(); let warning_count = self.warning_count.clone();
let inner = self.inner.clone(); let inner = self.inner.clone();
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id(); let event_id = inner.next_event_id();
let mm_infos = block_mm_infos let mm_infos = block_mm_infos
...@@ -204,7 +209,7 @@ impl KvEventPublisher { ...@@ -204,7 +209,7 @@ impl KvEventPublisher {
&token_ids, &token_ids,
&num_block_tokens, &num_block_tokens,
&block_hashes_u64, &block_hashes_u64,
lora_id, lora_name.as_deref(),
&warning_count, &warning_count,
mm_infos.as_deref(), mm_infos.as_deref(),
), ),
...@@ -879,7 +884,7 @@ impl KvRouter { ...@@ -879,7 +884,7 @@ impl KvRouter {
Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker)) Self::process_request_to_stream(py, self.inner.clone(), request, Some(tracker))
} }
#[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None))] #[pyo3(signature = (token_ids, router_config_override=None, request_id=None, block_mm_infos=None, lora_name=None))]
fn best_worker<'p>( fn best_worker<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
...@@ -887,6 +892,7 @@ impl KvRouter { ...@@ -887,6 +892,7 @@ impl KvRouter {
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
request_id: Option<String>, request_id: Option<String>,
block_mm_infos: Option<PyObject>, block_mm_infos: Option<PyObject>,
lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override { let router_config_override = if let Some(obj) = router_config_override {
let override_config: llm_rs::kv_router::RouterConfigOverride = let override_config: llm_rs::kv_router::RouterConfigOverride =
...@@ -911,7 +917,7 @@ impl KvRouter { ...@@ -911,7 +917,7 @@ impl KvRouter {
block_mm_infos.as_deref(), block_mm_infos.as_deref(),
router_config_override.as_ref(), router_config_override.as_ref(),
update_states, update_states,
None, // lora_name not exposed in Python API yet lora_name,
0.0, 0.0,
) )
.await .await
...@@ -948,16 +954,18 @@ impl KvRouter { ...@@ -948,16 +954,18 @@ impl KvRouter {
}) })
} }
#[pyo3(signature = (token_ids, lora_name=None))]
fn get_potential_loads<'p>( fn get_potential_loads<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
lora_name: Option<String>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let chooser = self.inner.chooser.clone(); let chooser = self.inner.chooser.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let loads = chooser let loads = chooser
.get_potential_loads(&token_ids, None) .get_potential_loads(&token_ids, None, lora_name.as_deref())
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -236,7 +236,8 @@ class ModelCardInstanceId: ...@@ -236,7 +236,8 @@ class ModelCardInstanceId:
def compute_block_hash_for_seq( def compute_block_hash_for_seq(
tokens: List[int], tokens: List[int],
kv_block_size: int, kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> List[int]: ) -> List[int]:
""" """
Compute block hashes for a sequence of tokens, optionally including multimodal metadata. Compute block hashes for a sequence of tokens, optionally including multimodal metadata.
...@@ -562,7 +563,7 @@ class KvIndexer: ...@@ -562,7 +563,7 @@ class KvIndexer:
... ...
def find_matches_for_request( def find_matches_for_request(
self, token_ids: List[int], lora_id: int self, token_ids: List[int], lora_name: Optional[str] = None
) -> OverlapScores: ) -> OverlapScores:
""" """
Return the overlapping scores of workers for the given token ids. Return the overlapping scores of workers for the given token ids.
...@@ -610,13 +611,14 @@ class ApproxKvIndexer: ...@@ -610,13 +611,14 @@ class ApproxKvIndexer:
... ...
def find_matches_for_request( def find_matches_for_request(
self, token_ids: List[int] self, token_ids: List[int], lora_name: Optional[str] = None
) -> OverlapScores: ) -> OverlapScores:
""" """
Return the overlapping scores of workers for the given token ids. Return the overlapping scores of workers for the given token ids.
Args: Args:
token_ids: List of token IDs to find matches for token_ids: List of token IDs to find matches for
lora_name: Optional LoRA adapter name for adapter-aware matching
Returns: Returns:
OverlapScores containing worker matching scores and frequencies OverlapScores containing worker matching scores and frequencies
...@@ -689,9 +691,9 @@ class KvEventPublisher: ...@@ -689,9 +691,9 @@ class KvEventPublisher:
token_ids: List[int], token_ids: List[int],
num_block_tokens: List[int], num_block_tokens: List[int],
block_hashes: List[int], block_hashes: List[int],
lora_id: int,
parent_hash: Optional[int] = None, parent_hash: Optional[int] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None, block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> None: ) -> None:
""" """
Publish a KV stored event. Publish a KV stored event.
...@@ -702,11 +704,11 @@ class KvEventPublisher: ...@@ -702,11 +704,11 @@ class KvEventPublisher:
token_ids: List of token IDs token_ids: List of token IDs
num_block_tokens: Number of tokens per block num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers) block_hashes: List of block hashes (signed 64-bit integers)
lora_id: The LoRA ID
parent_hash: Optional parent hash (signed 64-bit integer) parent_hash: Optional parent hash (signed 64-bit integer)
block_mm_infos: Optional list of multimodal info for each block. block_mm_infos: Optional list of multimodal info for each block.
Each item is either None or a dict with "mm_objects" key containing Each item is either None or a dict with "mm_objects" key containing
a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts. a list of {"mm_hash": int, "offsets": [[start, end], ...]} dicts.
lora_name: Optional LoRA adapter name for adapter-aware block hashing.
""" """
... ...
...@@ -1388,6 +1390,7 @@ class KvRouter: ...@@ -1388,6 +1390,7 @@ class KvRouter:
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
request_id: Optional[str] = None, request_id: Optional[str] = None,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None, block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None,
lora_name: Optional[str] = None,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
""" """
Find the best matching worker for the given tokens. Find the best matching worker for the given tokens.
...@@ -1413,6 +1416,7 @@ class KvRouter: ...@@ -1413,6 +1416,7 @@ class KvRouter:
async def get_potential_loads( async def get_potential_loads(
self, self,
token_ids: List[int], token_ids: List[int],
lora_name: Optional[str] = None,
) -> List[Dict[str, int]]: ) -> List[Dict[str, int]]:
""" """
Get potential prefill and decode loads for all workers. Get potential prefill and decode loads for all workers.
......
...@@ -10,13 +10,14 @@ Every cached KV block in a distributed LLM system needs four pieces of informati ...@@ -10,13 +10,14 @@ Every cached KV block in a distributed LLM system needs four pieces of informati
### 1. Local Block Hash (`LocalBlockHash`, u64) ### 1. Local Block Hash (`LocalBlockHash`, u64)
**What**: Hash of the tokens *within* a single block (e.g., 64 tokens). **What**: Hash of the tokens *within* a single block (e.g., 64 tokens), optionally including LoRA adapter name and multimodal metadata.
**Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens have the same local hash. **Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens (and same LoRA adapter) have the same local hash. When a LoRA adapter name is provided, it is length-prefixed and appended to the byte buffer before hashing, ensuring that blocks under different adapters (or the base model) always produce distinct hashes.
``` ```text
Block at position 5: tokens [101, 102, 103, ...] Block at position 5: tokens [101, 102, 103, ...]
LocalBlockHash = hash(tokens) = 0xABCD1234 LocalBlockHash = hash(tokens) = 0xABCD1234 (base model)
LocalBlockHash = hash(tokens || len("my-lora") || "my-lora") = 0xDEAD5678 (LoRA adapter)
``` ```
### 2. External Sequence Block Hash (`ExternalSequenceBlockHash`, u64) ### 2. External Sequence Block Hash (`ExternalSequenceBlockHash`, u64)
...@@ -39,6 +40,8 @@ block2 in B: seq_hash = hash(hash(hash(block0') || block1') || block2) = 0x2222 ...@@ -39,6 +40,8 @@ block2 in B: seq_hash = hash(hash(hash(block0') || block1') || block2) = 0x2222
> >
> In practice, the `ExternalSequenceBlockHash` may come directly from the inference engine (e.g., vLLM, TensorRT-LLM) using a rolling hash algorithm that we don't know or control. The engine computes these hashes internally and reports them via KV cache events. > In practice, the `ExternalSequenceBlockHash` may come directly from the inference engine (e.g., vLLM, TensorRT-LLM) using a rolling hash algorithm that we don't know or control. The engine computes these hashes internally and reports them via KV cache events.
> >
> **LoRA identity**: The engine is responsible for incorporating the LoRA adapter identity into the `ExternalSequenceBlockHash` before emitting KV events. Dynamo does not add LoRA information at the router layer. For example, vLLM does this via `_gen_lora_extra_hash_keys`, which appends the LoRA ID as extra keys when calling `hash_block_tokens(..., extra_keys)`. Any engine integrating with the KV router must follow the same convention to ensure correct cache isolation between LoRA adapters.
>
> **Implications for index implementations:** > **Implications for index implementations:**
> >
> - **RadixTree**: Can handle engine-provided hashes because it traverses the tree structure using `LocalBlockHash` for navigation and only uses `ExternalSequenceBlockHash` as an opaque identifier for lookups. It doesn't need to recompute hashes. > - **RadixTree**: Can handle engine-provided hashes because it traverses the tree structure using `LocalBlockHash` for navigation and only uses `ExternalSequenceBlockHash` as an opaque identifier for lookups. It doesn't need to recompute hashes.
......
...@@ -302,7 +302,7 @@ fn bench_hash(args: &Args) { ...@@ -302,7 +302,7 @@ fn bench_hash(args: &Args) {
for (i, tokens) in token_sequences.iter().enumerate() { for (i, tokens) in token_sequences.iter().enumerate() {
let start = Instant::now(); let start = Instant::now();
let _ = compute_block_hash_for_seq(tokens, args.block_size, None); let _ = compute_block_hash_for_seq(tokens, args.block_size, None, None);
let elapsed = start.elapsed(); let elapsed = start.elapsed();
if i >= warmup_iters { if i >= warmup_iters {
......
...@@ -349,7 +349,7 @@ mod tests { ...@@ -349,7 +349,7 @@ mod tests {
// 1. Before routing decision there should be no matches // 1. Before routing decision there should be no matches
let pre_scores = indexer let pre_scores = indexer
.find_matches_for_request(&tokens) .find_matches_for_request(&tokens, None)
.await .await
.expect("indexer offline"); .expect("indexer offline");
assert!(pre_scores.scores.is_empty()); assert!(pre_scores.scores.is_empty());
...@@ -366,7 +366,10 @@ mod tests { ...@@ -366,7 +366,10 @@ mod tests {
// Poll until we observe the match being registered // Poll until we observe the match being registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_id)) .get(&WorkerWithDpRank::from_worker_id(worker_id))
.copied() .copied()
...@@ -376,7 +379,10 @@ mod tests { ...@@ -376,7 +379,10 @@ mod tests {
// 3. After the TTL has passed the entry should expire automatically // 3. After the TTL has passed the entry should expire automatically
time::sleep(TTL + Duration::from_millis(50)).await; time::sleep(TTL + Duration::from_millis(50)).await;
let post_scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let post_scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert!(post_scores.scores.is_empty()); assert!(post_scores.scores.is_empty());
} }
...@@ -413,7 +419,10 @@ mod tests { ...@@ -413,7 +419,10 @@ mod tests {
// Wait until the worker is registered // Wait until the worker is registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id)) .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
}) })
...@@ -424,7 +433,10 @@ mod tests { ...@@ -424,7 +433,10 @@ mod tests {
// Ensure the worker's entries are gone // Ensure the worker's entries are gone
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
!s.scores !s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_id)) .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
}) })
...@@ -475,7 +487,10 @@ mod tests { ...@@ -475,7 +487,10 @@ mod tests {
// Ensure both workers are registered // Ensure both workers are registered
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0)) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied() .copied()
...@@ -492,7 +507,10 @@ mod tests { ...@@ -492,7 +507,10 @@ mod tests {
// Confirm the removed worker is gone, and the other remains. // Confirm the removed worker is gone, and the other remains.
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
!s.scores !s.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)) .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
&& s.scores && s.scores
...@@ -539,7 +557,10 @@ mod tests { ...@@ -539,7 +557,10 @@ mod tests {
// Ensure the indexer has registered the block // Ensure the indexer has registered the block
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&seq_a).await.unwrap(); let s = indexer
.find_matches_for_request(&seq_a, None)
.await
.unwrap();
s.scores s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_a)) .get(&WorkerWithDpRank::from_worker_id(worker_a))
.copied() .copied()
...@@ -551,7 +572,10 @@ mod tests { ...@@ -551,7 +572,10 @@ mod tests {
let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8]; let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
// Query the indexer for overlaps of Sequence B (before it has been routed anywhere) // Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap(); let overlap = indexer
.find_matches_for_request(&seq_b, None)
.await
.unwrap();
// Expect worker A to have an overlap score of 1 (shared first block) // Expect worker A to have an overlap score of 1 (shared first block)
assert_eq!( assert_eq!(
...@@ -606,7 +630,10 @@ mod tests { ...@@ -606,7 +630,10 @@ mod tests {
// Wait until both workers are reflected in overlap scores // Wait until both workers are reflected in overlap scores
spin_until(Duration::from_millis(100), || async { spin_until(Duration::from_millis(100), || async {
let s = indexer.find_matches_for_request(&tokens).await.unwrap(); let s = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
s.scores s.scores
.get(&WorkerWithDpRank::from_worker_id(worker_0)) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.copied() .copied()
...@@ -618,7 +645,10 @@ mod tests { ...@@ -618,7 +645,10 @@ mod tests {
}) })
.await; .await;
let scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert_eq!( assert_eq!(
scores scores
...@@ -777,7 +807,10 @@ mod tests { ...@@ -777,7 +807,10 @@ mod tests {
// Verify all 5 blocks are present (no pruning yet) // Verify all 5 blocks are present (no pruning yet)
for i in 0..5 { for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert_eq!( assert_eq!(
scores.scores.get(&worker).copied(), scores.scores.get(&worker).copied(),
Some(1), Some(1),
...@@ -803,7 +836,10 @@ mod tests { ...@@ -803,7 +836,10 @@ mod tests {
// Verify that the 4 oldest blocks are pruned // Verify that the 4 oldest blocks are pruned
for i in 0..4 { for i in 0..4 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert!( assert!(
scores.scores.get(&worker).copied().unwrap_or(0) == 0, scores.scores.get(&worker).copied().unwrap_or(0) == 0,
"Block {} should have been pruned but is still present", "Block {} should have been pruned but is still present",
...@@ -814,7 +850,10 @@ mod tests { ...@@ -814,7 +850,10 @@ mod tests {
// Verify the 2 newest blocks are present // Verify the 2 newest blocks are present
for i in 4..6 { for i in 4..6 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3]; let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap(); let scores = indexer
.find_matches_for_request(&tokens, None)
.await
.unwrap();
assert_eq!( assert_eq!(
scores.scores.get(&worker).copied(), scores.scores.get(&worker).copied(),
Some(1), Some(1),
......
...@@ -1122,7 +1122,9 @@ mod tests { ...@@ -1122,7 +1122,9 @@ mod tests {
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
let scores = indexer.find_matches_for_request(&[100, 200, 300]).await; let scores = indexer
.find_matches_for_request(&[100, 200, 300], None)
.await;
assert!(scores.is_ok()); assert!(scores.is_ok());
indexer.shutdown(); indexer.shutdown();
......
...@@ -302,6 +302,7 @@ pub trait KvIndexerInterface { ...@@ -302,6 +302,7 @@ pub trait KvIndexerInterface {
/// ### Arguments /// ### Arguments
/// ///
/// * `tokens` - A vector of `u32` tokens. /// * `tokens` - A vector of `u32` tokens.
/// * `lora_name` - Optional LoRA adapter name to include in block hash computation.
/// ///
/// ### Returns /// ### Returns
/// ///
...@@ -309,6 +310,7 @@ pub trait KvIndexerInterface { ...@@ -309,6 +310,7 @@ pub trait KvIndexerInterface {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError>; ) -> Result<OverlapScores, KvRouterError>;
/// Apply a `RouterEvent` to the KV store. /// Apply a `RouterEvent` to the KV store.
...@@ -510,8 +512,9 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -510,8 +512,9 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
Ok(self.backend.find_matches(&sequence, false)) Ok(self.backend.find_matches(&sequence, false))
} }
...@@ -972,13 +975,14 @@ impl KvIndexerInterface for KvIndexer { ...@@ -972,13 +975,14 @@ impl KvIndexerInterface for KvIndexer {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
tracing::debug!( tracing::debug!(
"Finding matches for request tokens: {:?} / len: {}", "Finding matches for request tokens: {:?} / len: {}",
tokens, tokens,
tokens.len() tokens.len()
); );
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
tracing::debug!("Computed sequence: {:?}", sequence); tracing::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
...@@ -1307,8 +1311,11 @@ impl KvIndexerInterface for LocalKvIndexer { ...@@ -1307,8 +1311,11 @@ impl KvIndexerInterface for LocalKvIndexer {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches_for_request(tokens).await self.indexer
.find_matches_for_request(tokens, lora_name)
.await
} }
async fn apply_event(&self, event: RouterEvent) { async fn apply_event(&self, event: RouterEvent) {
...@@ -1760,8 +1767,9 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -1760,8 +1767,9 @@ impl KvIndexerInterface for KvIndexerSharded {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
tokens: &[u32], tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None, lora_name);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
...@@ -2350,7 +2358,7 @@ mod tests { ...@@ -2350,7 +2358,7 @@ mod tests {
// Empty index should return no matches // Empty index should return no matches
let tokens = vec![1, 2, 3, 4]; let tokens = vec![1, 2, 3, 4];
let scores = index.find_matches_for_request(&tokens).await.unwrap(); let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
assert!(scores.scores.is_empty()); assert!(scores.scores.is_empty());
// Store some data and verify we can find it via tokens // Store some data and verify we can find it via tokens
...@@ -2362,7 +2370,7 @@ mod tests { ...@@ -2362,7 +2370,7 @@ mod tests {
// Note: find_matches_for_request computes block hashes from tokens, // Note: find_matches_for_request computes block hashes from tokens,
// so we need tokens that hash to the same LocalBlockHash values. // so we need tokens that hash to the same LocalBlockHash values.
// For this test, we just verify the method works without error. // For this test, we just verify the method works without error.
let scores = index.find_matches_for_request(&tokens).await.unwrap(); let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes // The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens // because find_matches_for_request computes different hashes from raw tokens
assert!(scores.scores.is_empty() || !scores.scores.is_empty()); assert!(scores.scores.is_empty() || !scores.scores.is_empty());
......
...@@ -174,6 +174,7 @@ impl KvIndexerInterface for NaiveNestedMap { ...@@ -174,6 +174,7 @@ impl KvIndexerInterface for NaiveNestedMap {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
_tokens: &[u32], _tokens: &[u32],
_lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench") unimplemented!("not used in bench")
} }
...@@ -384,6 +385,7 @@ impl KvIndexerInterface for InvertedIndex { ...@@ -384,6 +385,7 @@ impl KvIndexerInterface for InvertedIndex {
async fn find_matches_for_request( async fn find_matches_for_request(
&self, &self,
_tokens: &[u32], _tokens: &[u32],
_lora_name: Option<&str>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench") unimplemented!("not used in bench")
} }
......
...@@ -19,23 +19,34 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { ...@@ -19,23 +19,34 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data)) LocalBlockHash(compute_hash(data))
} }
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata. /// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// and LoRA adapter identity.
/// ///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation /// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce /// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes. /// different hashes.
///
/// When `lora_name` is provided, the adapter name is mixed into the XXH3 seed so that
/// blocks cached under different LoRA adapters (or the base model) produce distinct hashes.
/// Because LoRA identity applies uniformly to every block in a sequence, encoding it in the
/// seed is more efficient than appending per-block bytes and matches the approach used by
/// KVBM's `SaltHash`.
pub fn compute_block_hash_for_seq( pub fn compute_block_hash_for_seq(
tokens: &[u32], tokens: &[u32],
kv_block_size: u32, kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>, block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
lora_name: Option<&str>,
) -> Vec<LocalBlockHash> { ) -> Vec<LocalBlockHash> {
let seed = match lora_name.filter(|n| !n.is_empty()) {
Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())),
None => XXH3_SEED,
};
tokens tokens
.chunks_exact(kv_block_size as usize) .chunks_exact(kv_block_size as usize)
.enumerate() .enumerate()
.map(|(block_idx, chunk)| { .map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect(); let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx) && let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{ {
...@@ -51,7 +62,7 @@ pub fn compute_block_hash_for_seq( ...@@ -51,7 +62,7 @@ pub fn compute_block_hash_for_seq(
} }
} }
compute_block_hash(&bytes) LocalBlockHash(xxh3::xxh3_64_with_seed(&bytes, seed))
}) })
.collect() .collect()
} }
...@@ -176,13 +187,13 @@ pub struct ActiveLoad { ...@@ -176,13 +187,13 @@ pub struct ActiveLoad {
pub active_prefill_tokens: Option<u64>, pub active_prefill_tokens: Option<u64>,
} }
/// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional /// A [`LocalBlockHash`] is a hash computed from the token IDs, optional multimodal metadata,
/// lora_id of a block. /// and optional LoRA adapter name of a block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct LocalBlockHash(pub u64); pub struct LocalBlockHash(pub u64);
/// A sequence aware hash of a block where the hash is computed from the tokens_ids, extra_token_ids /// A sequence-aware hash of a block computed by the engine from token IDs, optional metadata,
/// and the optional lora_id of a block, PLUS the hash of the parent block. /// and the hash of the parent block.
/// ///
/// In this case, the hashing function is external and unknown. /// In this case, the hashing function is external and unknown.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
...@@ -575,6 +586,7 @@ pub struct TokensWithHashes { ...@@ -575,6 +586,7 @@ pub struct TokensWithHashes {
tokens: Vec<u32>, tokens: Vec<u32>,
block_size: u32, block_size: u32,
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>, block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
lora_name: Option<String>,
block_hashes: Option<Vec<LocalBlockHash>>, block_hashes: Option<Vec<LocalBlockHash>>,
seq_hashes: Option<Vec<SequenceHash>>, seq_hashes: Option<Vec<SequenceHash>>,
} }
...@@ -586,6 +598,7 @@ impl TokensWithHashes { ...@@ -586,6 +598,7 @@ impl TokensWithHashes {
tokens, tokens,
block_size, block_size,
block_mm_infos: None, block_mm_infos: None,
lora_name: None,
block_hashes: None, block_hashes: None,
seq_hashes: None, seq_hashes: None,
} }
...@@ -597,6 +610,12 @@ impl TokensWithHashes { ...@@ -597,6 +610,12 @@ impl TokensWithHashes {
self self
} }
/// Sets the LoRA adapter name for hash computation.
pub fn with_lora_name(mut self, name: String) -> Self {
self.lora_name = Some(name);
self
}
/// Returns a reference to the tokens. /// Returns a reference to the tokens.
pub fn tokens(&self) -> &[u32] { pub fn tokens(&self) -> &[u32] {
&self.tokens &self.tokens
...@@ -629,6 +648,7 @@ impl TokensWithHashes { ...@@ -629,6 +648,7 @@ impl TokensWithHashes {
&self.tokens, &self.tokens,
self.block_size, self.block_size,
self.block_mm_infos.as_deref(), self.block_mm_infos.as_deref(),
self.lora_name.as_deref(),
)); ));
} }
self.block_hashes.as_ref().unwrap() self.block_hashes.as_ref().unwrap()
...@@ -708,22 +728,67 @@ mod tests { ...@@ -708,22 +728,67 @@ mod tests {
#[case(32)] #[case(32)]
#[case(64)] #[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) { fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
// create a sequence of kv_block_size elements
let sequence = (0..kv_block_size).collect::<Vec<u32>>(); let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of kv_block_size + 1 elements
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>(); let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 2 * kv_block_size + 1 elements
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>(); let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None, None);
assert_eq!(hashes.len(), 2); assert_eq!(hashes.len(), 2);
} }
#[test]
fn test_lora_name_produces_different_hash() {
let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None);
let lora_a = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-a"));
let lora_b = compute_block_hash_for_seq(&tokens, 4, None, Some("adapter-b"));
assert_ne!(base[0], lora_a[0]);
assert_ne!(base[0], lora_b[0]);
assert_ne!(lora_a[0], lora_b[0]);
}
#[test]
fn test_lora_name_none_matches_legacy() {
let tokens: Vec<u32> = (0..8).collect();
let hashes_none = compute_block_hash_for_seq(&tokens, 4, None, None);
let hashes_none2 = compute_block_hash_for_seq(&tokens, 4, None, None);
assert_eq!(hashes_none, hashes_none2);
}
#[test]
fn test_lora_name_empty_string_normalized_to_none() {
let tokens: Vec<u32> = (0..4).collect();
let base = compute_block_hash_for_seq(&tokens, 4, None, None);
let empty = compute_block_hash_for_seq(&tokens, 4, None, Some(""));
assert_eq!(
base, empty,
"empty lora_name should be treated as base model"
);
}
#[test]
fn test_tokens_with_hashes_lora() {
let tokens: Vec<u32> = (0..8).collect();
let mut base = TokensWithHashes::new(tokens.clone(), 4);
let base_hashes = base.get_or_compute_block_hashes().to_vec();
let mut with_lora =
TokensWithHashes::new(tokens, 4).with_lora_name("my-adapter".to_string());
let lora_hashes = with_lora.get_or_compute_block_hashes().to_vec();
assert_eq!(base_hashes.len(), lora_hashes.len());
for (b, l) in base_hashes.iter().zip(lora_hashes.iter()) {
assert_ne!(b, l);
}
}
#[test] #[test]
fn test_local_block_hash_serialization() { fn test_local_block_hash_serialization() {
let hash = LocalBlockHash(12345); let hash = LocalBlockHash(12345);
......
...@@ -223,7 +223,7 @@ impl DynamoEventManager { ...@@ -223,7 +223,7 @@ impl DynamoEventManager {
tokens, tokens,
parent_hash, parent_hash,
block_size, block_size,
None, // lora_id None, // lora_name
None, // tier None, // tier
None, // data_parallel_rank None, // data_parallel_rank
) )
......
...@@ -40,7 +40,7 @@ impl KvEventConsolidatorHandle { ...@@ -40,7 +40,7 @@ impl KvEventConsolidatorHandle {
token_ids: Vec<u32>, token_ids: Vec<u32>,
parent_hash: Option<String>, parent_hash: Option<String>,
block_size: usize, block_size: usize,
lora_id: Option<u64>, lora_name: Option<String>,
tier: Option<StorageTier>, tier: Option<StorageTier>,
data_parallel_rank: Option<i32>, data_parallel_rank: Option<i32>,
) { ) {
...@@ -51,7 +51,7 @@ impl KvEventConsolidatorHandle { ...@@ -51,7 +51,7 @@ impl KvEventConsolidatorHandle {
token_ids, token_ids,
parent_hash, parent_hash,
block_size, block_size,
lora_id.map(|id| id as i32), lora_name,
tier, tier,
data_parallel_rank, data_parallel_rank,
); );
......
...@@ -31,7 +31,6 @@ struct EventBatch( ...@@ -31,7 +31,6 @@ struct EventBatch(
); );
/// Event types matching vLLM's format /// Event types matching vLLM's format
/// Note: Uses i32 for token_ids, block_size, and lora_id to match vLLM's ZmqEventPublisher
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
enum Event { enum Event {
...@@ -41,7 +40,8 @@ enum Event { ...@@ -41,7 +40,8 @@ enum Event {
parent_block_hash: Option<u64>, parent_block_hash: Option<u64>,
token_ids: Vec<i32>, token_ids: Vec<i32>,
block_size: i32, block_size: i32,
lora_id: Option<i32>, #[serde(default, skip_serializing_if = "Option::is_none")]
lora_name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
medium: Option<String>, medium: Option<String>,
}, },
...@@ -68,15 +68,13 @@ impl Event { ...@@ -68,15 +68,13 @@ impl Event {
parent_hash, parent_hash,
token_ids, token_ids,
block_size, block_size,
lora_id, lora_name,
source: _, // Source used for logging only, not sent to router source: _,
} => { } => {
// Parse block hash - fail if invalid to prevent corruption
let parsed_hash = block_hash let parsed_hash = block_hash
.parse::<u64>() .parse::<u64>()
.with_context(|| format!("Failed to parse block_hash: {}", block_hash))?; .with_context(|| format!("Failed to parse block_hash: {}", block_hash))?;
// Parse parent hash if present - fail if invalid
let parsed_parent = parent_hash let parsed_parent = parent_hash
.map(|h| { .map(|h| {
h.parse::<u64>() h.parse::<u64>()
...@@ -84,8 +82,6 @@ impl Event { ...@@ -84,8 +82,6 @@ impl Event {
}) })
.transpose()?; .transpose()?;
// Convert u32 token_ids to i32 for vLLM compatibility
// Token IDs should never exceed i32::MAX in practice, but we handle it gracefully
let token_ids_i32: Vec<i32> = token_ids let token_ids_i32: Vec<i32> = token_ids
.into_iter() .into_iter()
.map(|t| { .map(|t| {
...@@ -96,7 +92,6 @@ impl Event { ...@@ -96,7 +92,6 @@ impl Event {
}) })
.collect(); .collect();
// Convert usize block_size to i32 for vLLM compatibility
let block_size_i32 = i32::try_from(block_size).unwrap_or_else(|_| { let block_size_i32 = i32::try_from(block_size).unwrap_or_else(|_| {
tracing::warn!( tracing::warn!(
"Block size {} exceeds i32::MAX, clamping to i32::MAX", "Block size {} exceeds i32::MAX, clamping to i32::MAX",
...@@ -105,16 +100,13 @@ impl Event { ...@@ -105,16 +100,13 @@ impl Event {
i32::MAX i32::MAX
}); });
// lora_id is already Option<i32> in ConsolidatedEvent::Store
let lora_id_i32 = lora_id;
Ok(Event::BlockStored { Ok(Event::BlockStored {
block_hashes: vec![parsed_hash], block_hashes: vec![parsed_hash],
parent_block_hash: parsed_parent, parent_block_hash: parsed_parent,
token_ids: token_ids_i32, token_ids: token_ids_i32,
block_size: block_size_i32, block_size: block_size_i32,
lora_id: lora_id_i32, lora_name,
medium: None, // Not provided by ConsolidatedEvent medium: None,
}) })
} }
ConsolidatedEvent::Remove { ConsolidatedEvent::Remove {
......
...@@ -145,12 +145,9 @@ enum VllmRawEvent { ...@@ -145,12 +145,9 @@ enum VllmRawEvent {
parent_block_hash: Option<BlockHash>, parent_block_hash: Option<BlockHash>,
token_ids: Vec<i32>, token_ids: Vec<i32>,
block_size: i32, block_size: i32,
lora_id: Option<i32>,
#[serde(default)] #[serde(default)]
medium: Option<String>, medium: Option<String>,
#[serde(default)] #[serde(default)]
#[allow(dead_code)]
// Reserved for future use, needed for vLLM 0.14.0 deserialization
lora_name: Option<String>, lora_name: Option<String>,
}, },
#[serde(rename = "BlockRemoved")] #[serde(rename = "BlockRemoved")]
...@@ -279,9 +276,8 @@ fn process_event( ...@@ -279,9 +276,8 @@ fn process_event(
parent_block_hash, parent_block_hash,
token_ids, token_ids,
block_size, block_size,
lora_id,
medium, medium,
lora_name: _, // Not used yet, lora_id is still used for backwards compat lora_name,
} => { } => {
let storage_tier = medium let storage_tier = medium
.as_ref() .as_ref()
...@@ -370,7 +366,7 @@ fn process_event( ...@@ -370,7 +366,7 @@ fn process_event(
block_tokens, block_tokens,
current_parent.clone(), current_parent.clone(),
block_size_usize, block_size_usize,
lora_id, lora_name.clone(),
Some(storage_tier), Some(storage_tier),
data_parallel_rank, data_parallel_rank,
); );
......
...@@ -188,8 +188,8 @@ pub enum ConsolidatedEvent { ...@@ -188,8 +188,8 @@ pub enum ConsolidatedEvent {
parent_hash: Option<String>, parent_hash: Option<String>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
block_size: usize, block_size: usize,
lora_id: Option<i32>, lora_name: Option<String>,
source: String, // The source where it was first stored (vllm or kvbm) source: String,
}, },
/// Block removed (removed from all sources) /// Block removed (removed from all sources)
Remove { Remove {
...@@ -257,7 +257,7 @@ impl CacheStatusTracker { ...@@ -257,7 +257,7 @@ impl CacheStatusTracker {
token_ids: Vec<u32>, token_ids: Vec<u32>,
parent_hash: Option<String>, parent_hash: Option<String>,
block_size: usize, block_size: usize,
lora_id: Option<i32>, lora_name: Option<String>,
tier: Option<StorageTier>, tier: Option<StorageTier>,
data_parallel_rank: Option<i32>, data_parallel_rank: Option<i32>,
) -> bool { ) -> bool {
...@@ -327,7 +327,7 @@ impl CacheStatusTracker { ...@@ -327,7 +327,7 @@ impl CacheStatusTracker {
.as_ref() .as_ref()
.map(|p| &p[..16.min(p.len())]) .map(|p| &p[..16.min(p.len())])
.unwrap_or("none"), .unwrap_or("none"),
lora_id, lora_name,
data_parallel_rank, data_parallel_rank,
&token_ids &token_ids
); );
...@@ -366,7 +366,7 @@ impl CacheStatusTracker { ...@@ -366,7 +366,7 @@ impl CacheStatusTracker {
parent_hash: resolved_parent_hash, parent_hash: resolved_parent_hash,
token_ids, token_ids,
block_size, block_size,
lora_id, lora_name,
source: source.to_str().to_string(), source: source.to_str().to_string(),
}); });
......
...@@ -370,8 +370,14 @@ impl KvRouter { ...@@ -370,8 +370,14 @@ impl KvRouter {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = tracing::info_span!("kv_router.compute_block_hashes") let block_hashes = tracing::info_span!("kv_router.compute_block_hashes").in_scope(|| {
.in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, block_mm_infos)); compute_block_hash_for_seq(
tokens,
self.block_size,
block_mm_infos,
lora_name.as_deref(),
)
});
let hash_elapsed = start.elapsed(); let hash_elapsed = start.elapsed();
let overlap_scores = self let overlap_scores = self
...@@ -381,12 +387,12 @@ impl KvRouter { ...@@ -381,12 +387,12 @@ impl KvRouter {
.await?; .await?;
let find_matches_elapsed = start.elapsed(); let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| { let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking( self.kv_router_config.compute_seq_hashes_for_tracking(
tokens, tokens,
self.block_size, self.block_size,
router_config_override, router_config_override,
lora_name.as_deref(),
) )
}); });
let seq_hash_elapsed = start.elapsed(); let seq_hash_elapsed = start.elapsed();
...@@ -453,6 +459,7 @@ impl KvRouter { ...@@ -453,6 +459,7 @@ impl KvRouter {
tokens, tokens,
self.block_size, self.block_size,
router_config_override, router_config_override,
lora_name.as_deref(),
); );
if let Err(e) = self if let Err(e) = self
...@@ -506,8 +513,9 @@ impl KvRouter { ...@@ -506,8 +513,9 @@ impl KvRouter {
&self, &self,
tokens: &[u32], tokens: &[u32],
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<&str>,
) -> Result<u32, KvRouterError> { ) -> Result<u32, KvRouterError> {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0)) Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
} }
...@@ -517,15 +525,17 @@ impl KvRouter { ...@@ -517,15 +525,17 @@ impl KvRouter {
&self, &self,
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>,
) -> Result<Vec<PotentialLoad>> { ) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking( let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
tokens, tokens,
self.block_size, self.block_size,
router_config_override, router_config_override,
lora_name,
); );
Ok(self Ok(self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment