Unverified Commit 18986010 authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

feat: Add KV event consolidator for KVBM (vllm) and router integration (#3725)


Signed-off-by: default avatarkrishung5 <krish@nvidia.com>
parent 95214e8b
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Launches a frontend with the KV router plus two vLLM workers that use the
# KVBM connector. The frontend and the first worker run in the background; the
# second worker runs in the foreground and keeps the script alive.

set -e
# On exit, signal the whole process group so the background jobs are stopped too.
trap 'echo Cleaning up...; kill 0' EXIT

# Set deterministic hash for KV event IDs
export PYTHONHASHSEED=0

# Common configuration
MODEL="Qwen/Qwen3-0.6B"

# run frontend + KV router
python -m dynamo.frontend \
    --router-mode kv \
    --http-port 8000 \
    --router-reset-states &

# run workers with KVBM enabled
# --enforce-eager is added for quick deployment; remove this flag for production use.
# Each worker needs unique ZMQ ports to avoid KVBM coordination conflicts.
# Both workers are pinned to GPU 0 and each is limited to 40% of its memory.
DYN_KVBM_LEADER_ZMQ_PUB_PORT=56001 \
DYN_KVBM_LEADER_ZMQ_ACK_PORT=56002 \
CUDA_VISIBLE_DEVICES=0 DYN_KVBM_CPU_CACHE_GB=2 \
    python3 -m dynamo.vllm \
    --model "$MODEL" \
    --enforce-eager \
    --connector kvbm --gpu-memory-utilization 0.4 &

DYN_KVBM_LEADER_ZMQ_PUB_PORT=56003 \
DYN_KVBM_LEADER_ZMQ_ACK_PORT=56004 \
CUDA_VISIBLE_DEVICES=0 DYN_KVBM_CPU_CACHE_GB=2 \
    python3 -m dynamo.vllm \
    --model "$MODEL" \
    --enforce-eager \
    --connector kvbm --gpu-memory-utilization 0.4
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Launches a disaggregated setup: a frontend acting as the decode router, a
# standalone router for prefill workers, two decode workers (GPUs 0 and 1) and
# two KVBM-enabled prefill workers (GPUs 2 and 3). Everything except the last
# prefill worker runs in the background.

set -e
# On exit, signal the whole process group so the background jobs are stopped too.
trap 'echo Cleaning up...; kill 0' EXIT

# Set deterministic hash for KV event IDs
export PYTHONHASHSEED=0

# Common configuration
MODEL="Qwen/Qwen3-0.6B"

# run decode router with kv-overlap-score-weight 0 for pure load balancing
python -m dynamo.frontend \
    --router-mode kv \
    --http-port 8000 \
    --kv-overlap-score-weight 0 \
    --router-reset-states &

# run standalone router service for prefill workers
python -m dynamo.router \
    --endpoint dynamo.prefill.generate \
    --router-reset-states \
    --no-track-active-blocks &

# two decode workers (without KVBM)
# --enforce-eager is added for quick deployment; remove this flag for production use.
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
    --model "$MODEL" \
    --enforce-eager &

CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
    --model "$MODEL" \
    --enforce-eager &

# two prefill workers with KVBM enabled
# Each worker needs unique ZMQ ports to avoid KVBM coordination conflicts
DYN_KVBM_LEADER_ZMQ_PUB_PORT=56001 \
DYN_KVBM_LEADER_ZMQ_ACK_PORT=56002 \
CUDA_VISIBLE_DEVICES=2 DYN_KVBM_CPU_CACHE_GB=20 \
    python3 -m dynamo.vllm \
    --model "$MODEL" \
    --enforce-eager \
    --is-prefill-worker \
    --connector kvbm &

DYN_KVBM_LEADER_ZMQ_PUB_PORT=56003 \
DYN_KVBM_LEADER_ZMQ_ACK_PORT=56004 \
CUDA_VISIBLE_DEVICES=3 DYN_KVBM_CPU_CACHE_GB=20 \
    python3 -m dynamo.vllm \
    --model "$MODEL" \
    --enforce-eager \
    --is-prefill-worker \
    --connector kvbm
......@@ -73,6 +73,18 @@ class Config:
# dump config to file
dump_config_to: Optional[str] = None
def has_connector(self, connector_name: str) -> bool:
    """
    Report whether the named connector is enabled on this config.

    Args:
        connector_name: Connector to look for (e.g., "kvbm", "nixl")

    Returns:
        False when no connector list is configured; otherwise whether
        connector_name is a member of the connector list.
    """
    enabled_connectors = self.connector_list
    if enabled_connectors is None:
        return False
    return connector_name in enabled_connectors
@register_encoder(Config)
def _preprocess_for_encode_config(config: Config) -> Dict[str, Any]:
......@@ -311,7 +323,7 @@ async def configure_ports(runtime: DistributedRuntime, config: Config):
logger.info(f"Allocated ZMQ KV events port: {kv_port} (worker_id={worker_id})")
# Check if NIXL is needed based on connector list
needs_nixl = config.connector_list and "nixl" in config.connector_list
needs_nixl = config.has_connector("nixl")
if needs_nixl:
# Allocate side channel ports
......
......@@ -5,6 +5,7 @@ import asyncio
import logging
import os
import signal
from typing import Optional
import uvloop
from prometheus_client import REGISTRY
......@@ -24,6 +25,7 @@ from dynamo.llm import (
fetch_llm,
register_llm,
)
from dynamo.llm.vllm_integration.consolidator_config import get_consolidator_endpoints
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import (
......@@ -122,11 +124,21 @@ def setup_kv_event_publisher(
component,
generate_endpoint,
vllm_config,
):
consolidator_enabled: bool = False,
consolidator_port: Optional[int] = 5558,
) -> Optional[ZmqKvEventPublisher]:
"""
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args:
config: Worker configuration
component: Component for runtime integration
generate_endpoint: Endpoint for worker ID
vllm_config: vLLM configuration
consolidator_enabled: If True, subscribe to the KV event consolidator's ZMQ endpoint
consolidator_port: Port where kv event consolidator publishes (default: 5558)
Returns:
List of ZmqKvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.
"""
......@@ -143,11 +155,21 @@ def setup_kv_event_publisher(
kv_publishers = []
for dp_rank in range(data_parallel_size):
if consolidator_enabled:
# TODO: Use different port for each dp_rank once KVBM supports DP
zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
logger.info(
f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
)
else:
# Each dp_rank publishes to a different port
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=dp_rank,
).replace("*", "127.0.0.1")
logger.info(
f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
)
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.connection_id(),
......@@ -191,6 +213,12 @@ def setup_vllm_engine(config, stat_logger=None):
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# Set up consolidator endpoints if KVBM is enabled
consolidator_endpoints = None
if config.has_connector("kvbm"):
consolidator_endpoints = get_consolidator_endpoints(vllm_config)
vllm_config.consolidator_endpoints = consolidator_endpoints
factory = []
if stat_logger:
factory.append(stat_logger)
......@@ -282,9 +310,29 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
runtime, component, engine_client, default_sampling_params
)
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
consolidator_enabled = False
consolidator_port = None
if (
hasattr(vllm_config, "consolidator_endpoints")
and vllm_config.consolidator_endpoints
):
# Extract connect endpoint (third element) for clients to subscribe
# consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
consolidator_output_endpoint = vllm_config.consolidator_endpoints[2]
consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
consolidator_enabled = True
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
config,
component,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
consolidator_port=consolidator_port,
)
if kv_publishers:
handler.kv_publishers = kv_publishers
......@@ -368,9 +416,29 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params,
)
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
consolidator_enabled = False
consolidator_port = None
if (
hasattr(vllm_config, "consolidator_endpoints")
and vllm_config.consolidator_endpoints
):
# Extract connect endpoint (third element) for clients to subscribe
# consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
consolidator_output_endpoint = vllm_config.consolidator_endpoints[2]
consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
consolidator_enabled = True
# Set up KV event publisher for prefix caching if enabled
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
config,
component,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
consolidator_port=consolidator_port,
)
if kv_publishers:
handler.kv_publishers = kv_publishers
......
......@@ -6,6 +6,7 @@ use anyhow::Result;
use dynamo_llm::block_manager::block::{
data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical,
};
use dynamo_llm::block_manager::kv_consolidator::KvEventConsolidatorConfig;
use dynamo_llm::block_manager::offload::filter::FrequencyFilter;
use dynamo_llm::block_manager::{BasicMetadata, BlockParallelismStrategy};
......@@ -252,6 +253,7 @@ pub struct BlockManagerBuilder {
page_size: usize,
disable_device_pool: bool,
kvbm_metrics: Option<dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics>,
consolidator_config: Option<(String, String)>, // (vllm_endpoint, output_endpoint)
}
impl BlockManagerBuilder {
......@@ -286,6 +288,11 @@ impl BlockManagerBuilder {
self
}
/// Records the KV event consolidator endpoints on the builder.
///
/// The two strings are stored as a `(vllm_endpoint, output_endpoint)` pair;
/// calling this again replaces any previously stored pair.
pub fn consolidator_config(mut self, vllm_endpoint: String, output_endpoint: String) -> Self {
self.consolidator_config = Some((vllm_endpoint, output_endpoint));
self
}
/// Async build (call from an async context).
pub async fn build(self) -> Result<BlockManager> {
let worker_id = self.worker_id;
......@@ -356,6 +363,12 @@ impl BlockManagerBuilder {
if let Some(kvbm_metrics) = self.kvbm_metrics {
config_builder = config_builder.kvbm_metrics(Some(kvbm_metrics));
}
if let Some((vllm_ep, output_ep)) = self.consolidator_config {
let consolidator_config = KvEventConsolidatorConfig::new(vllm_ep, output_ep);
config_builder = config_builder.consolidator_config(consolidator_config);
}
let config = config_builder.build()?;
let resources =
......
......@@ -92,6 +92,8 @@ impl KvConnectorLeader {
drt: PyDistributedRuntime,
page_size: usize,
leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeader initialized with worker_id: {}",
......@@ -114,6 +116,9 @@ impl KvConnectorLeader {
{
let slot_manager_cell = slot_manager_cell.clone();
// Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -124,15 +129,27 @@ impl KvConnectorLeader {
return;
}
let block_manager = match BlockManagerBuilder::new()
let mut block_manager_builder = BlockManagerBuilder::new()
.worker_id(0)
.leader(leader_py)
.page_size(page_size)
.disable_device_pool(false)
.kvbm_metrics(kvbm_metrics_clone.clone())
.build()
.await
.kvbm_metrics(kvbm_metrics_clone.clone());
// Add consolidator config if provided
if let (Some(vllm_ep), Some(output_ep)) =
(consolidator_vllm_ep, consolidator_output_ep)
{
tracing::debug!(
"Adding consolidator config to BlockManager: vllm={}, output={}",
vllm_ep,
output_ep
);
block_manager_builder =
block_manager_builder.consolidator_config(vllm_ep, output_ep);
}
let block_manager = match block_manager_builder.build().await {
Ok(bm) => bm,
Err(e) => {
tracing::error!("Failed to build BlockManager: {}", e);
......@@ -547,23 +564,40 @@ pub struct PyKvConnectorLeader {
#[pymethods]
impl PyKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, page_size, leader))]
#[pyo3(signature = (worker_id, drt, page_size, leader, consolidator_vllm_endpoint=None, consolidator_output_endpoint=None))]
pub fn new(
worker_id: String,
drt: PyDistributedRuntime,
page_size: usize,
leader: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
) -> Self {
// Initialize logging for the vLLM connector
dynamo_runtime::logging::init();
let enable_kvbm_record = std::env::var("ENABLE_KVBM_RECORD")
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false);
let connector_leader: Box<dyn Leader> = if enable_kvbm_record {
Box::new(recorder::KvConnectorLeaderRecorder::new(
worker_id, drt, page_size, leader,
worker_id,
drt,
page_size,
leader,
consolidator_vllm_endpoint,
consolidator_output_endpoint,
))
} else {
Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader))
Box::new(KvConnectorLeader::new(
worker_id,
drt,
page_size,
leader,
consolidator_vllm_endpoint,
consolidator_output_endpoint,
))
};
Self { connector_leader }
}
......
......@@ -90,6 +90,8 @@ impl KvConnectorLeaderRecorder {
drt: PyDistributedRuntime,
page_size: usize,
leader_py: PyKvbmLeader,
consolidator_vllm_endpoint: Option<String>,
consolidator_output_endpoint: Option<String>,
) -> Self {
tracing::info!(
"KvConnectorLeaderRecorder initialized with worker_id: {}",
......@@ -130,6 +132,9 @@ impl KvConnectorLeaderRecorder {
{
let slot_manager_cell = slot_manager_cell.clone();
// Capture consolidator endpoints for the async block
let consolidator_vllm_ep = consolidator_vllm_endpoint.clone();
let consolidator_output_ep = consolidator_output_endpoint.clone();
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
......@@ -140,15 +145,22 @@ impl KvConnectorLeaderRecorder {
return;
}
let block_manager = match BlockManagerBuilder::new()
let mut block_manager_builder = BlockManagerBuilder::new()
.worker_id(0)
.leader(leader_py)
.page_size(page_size)
.disable_device_pool(false)
.kvbm_metrics(kvbm_metrics_clone.clone())
.build()
.await
.kvbm_metrics(kvbm_metrics_clone.clone());
// Add consolidator config if provided
if let (Some(vllm_ep), Some(output_ep)) =
(consolidator_vllm_ep, consolidator_output_ep)
{
block_manager_builder =
block_manager_builder.consolidator_config(vllm_ep, output_ep);
}
let block_manager = match block_manager_builder.build().await {
Ok(bm) => bm,
Err(e) => {
tracing::error!("Failed to build BlockManager: {}", e);
......
......@@ -63,8 +63,46 @@ class KvConnectorLeader:
leader = KvbmLeader(world_size, drt=self.drt)
print(f"KvConnectorLeader initialized with engine_id: {engine_id}")
# Get kv event consolidator endpoints from vllm_config (pre-computed in main.py)
consolidator_vllm_endpoint = None
consolidator_output_endpoint = None
self._consolidator_output_port = None
if (
hasattr(vllm_config, "consolidator_endpoints")
and vllm_config.consolidator_endpoints
):
# Unpack all three endpoints
# [0]: vllm_endpoint (for consolidator to subscribe to vLLM)
# [1]: output_bind_endpoint (for consolidator to bind/publish)
# [2]: output_connect_endpoint (for clients to connect)
(
consolidator_vllm_endpoint,
consolidator_output_endpoint,
_consolidator_output_connect_endpoint, # Not needed here
) = vllm_config.consolidator_endpoints
self._consolidator_output_port = int(
consolidator_output_endpoint.split(":")[-1]
)
# Pass endpoints to Rust
self._connector = RustKvConnectorLeader(
engine_id,
self.drt,
vllm_config.cache_config.block_size,
leader,
consolidator_vllm_endpoint=consolidator_vllm_endpoint,
consolidator_output_endpoint=consolidator_output_endpoint,
)
else:
# No kv event consolidator - pass None to Rust
self._connector = RustKvConnectorLeader(
engine_id, self.drt, vllm_config.cache_config.block_size, leader
engine_id,
self.drt,
vllm_config.cache_config.block_size,
leader,
consolidator_vllm_endpoint=None,
consolidator_output_endpoint=None,
)
# KV Connector
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Helper functions for KV Event Consolidator configuration.
"""
import logging
import os
from typing import Optional, Tuple
from vllm.distributed.kv_events import ZmqEventPublisher
logger = logging.getLogger(__name__)
def is_truthy(val: str) -> bool:
    """
    Check if a string represents a truthy value.

    Truthy values: "1", "true", "on", "yes" (case-insensitive, surrounding
    whitespace ignored).

    Args:
        val: The string value to check. ``None`` is treated as not truthy.

    Returns:
        True if the value is truthy, False otherwise
    """
    if val is None:
        return False
    # Environment values often carry stray whitespace (e.g. from .env files);
    # strip it so "true " is not silently read as "disabled".
    return val.strip().lower() in ("1", "true", "on", "yes")
def should_enable_consolidator(vllm_config) -> bool:
    """
    Determine if the KV Event Consolidator should be enabled based on vLLM config.

    The consolidator can be controlled via the DYN_KVBM_KV_EVENTS_ENABLE_CONSOLIDATOR
    environment variable:
    - Set to truthy values ("1", "true", "on", "yes") to enable (default)
    - Set to any other value to disable
    - If not set, defaults to enabled and auto-detects based on KVBM connector
      and prefix caching settings

    Even when enabled via the environment, the consolidator is only turned on
    if the config uses the DynamoConnector (directly or inside a PdConnector)
    and prefix caching is enabled.

    Args:
        vllm_config: The vLLM VllmConfig object

    Returns:
        True if consolidator should be enabled, False otherwise
    """
    # Check environment variable override
    env_override = os.getenv("DYN_KVBM_KV_EVENTS_ENABLE_CONSOLIDATOR", "true")
    if not is_truthy(env_override):
        logger.info(
            "KV Event Consolidator disabled via DYN_KVBM_KV_EVENTS_ENABLE_CONSOLIDATOR environment variable"
        )
        return False

    # Auto-detection: Check if KVBM connector is in use
    kv_transfer_config = getattr(vllm_config, "kv_transfer_config", None)
    if kv_transfer_config is None:
        logger.warning(
            "KV Event Consolidator is not enabled due to missing kv_transfer_config"
        )
        return False

    # Check if DynamoConnector is present
    connector_name = getattr(kv_transfer_config, "kv_connector", None)
    is_dynamo_connector = connector_name == "DynamoConnector"

    # For multi-connector (PdConnector), check if DynamoConnector is in the list
    if connector_name == "PdConnector":
        # The attribute can exist but be None, in which case the getattr
        # default is not used; normalise both levels to empty containers.
        extra_config = (
            getattr(kv_transfer_config, "kv_connector_extra_config", None) or {}
        )
        connectors = extra_config.get("connectors") or []
        is_dynamo_connector = any(
            isinstance(conn, dict) and conn.get("kv_connector") == "DynamoConnector"
            for conn in connectors
        )

    if not is_dynamo_connector:
        logger.warning(
            f"KV Event Consolidator is not enabled: DynamoConnector (KVBM) not found (current connector: {connector_name})"
        )
        return False

    # Check if prefix caching is enabled (required for KV events)
    if not vllm_config.cache_config.enable_prefix_caching:
        logger.warning(
            "KVBM connector requires prefix caching to be enabled for KV event consolidation. "
            "KV Event Consolidator is not enabled."
        )
        return False

    logger.info(
        "KV Event Consolidator auto-enabled (KVBM connector + prefix caching detected)"
    )
    return True
def get_consolidator_endpoints(vllm_config) -> Optional[Tuple[str, str, str]]:
    """
    Get consolidator endpoints from vLLM config.

    Args:
        vllm_config: The vLLM VllmConfig object

    Returns:
        Tuple of (vllm_endpoint, output_bind_endpoint, output_connect_endpoint) if
        consolidator should be enabled, where:
        - vllm_endpoint: ZMQ endpoint for consolidator to subscribe to vLLM events
        - output_bind_endpoint: ZMQ endpoint for consolidator to bind and publish (tcp://0.0.0.0:PORT)
        - output_connect_endpoint: ZMQ endpoint for clients to connect (tcp://127.0.0.1:PORT)
        None if consolidator should not be enabled

    Raises:
        ValueError: If DYN_KVBM_LEADER_ZMQ_PUB_PORT is not a positive integer, or
            if the derived consolidator port exceeds 65535.
    """
    if not should_enable_consolidator(vllm_config):
        return None

    # Get vLLM's ZMQ endpoint
    # TODO: Data parallelism is not yet supported for consolidator
    # Currently assumes data_parallel_rank=0
    base_endpoint = vllm_config.kv_events_config.endpoint
    data_parallel_rank = (
        getattr(vllm_config.parallel_config, "data_parallel_rank", 0) or 0
    )
    if data_parallel_rank != 0:
        logger.warning(
            f"KV Event Consolidator does not yet support data_parallel_rank={data_parallel_rank}. "
            "Only rank 0 is supported. Proceeding with rank 0."
        )
        data_parallel_rank = 0

    vllm_endpoint = ZmqEventPublisher.offset_endpoint_port(
        base_endpoint,
        data_parallel_rank=data_parallel_rank,
    ).replace("*", "127.0.0.1")

    # Derive consolidator port deterministically from KVBM leader ZMQ pub port
    # Default value (56001) aligns with Rust constant DEFAULT_LEADER_ZMQ_PUB_PORT defined in:
    # dynamo/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs
    kvbm_pub_port_str = os.getenv("DYN_KVBM_LEADER_ZMQ_PUB_PORT", "56001")
    try:
        kvbm_pub_port = int(kvbm_pub_port_str)
    except ValueError as exc:
        # Name the offending variable; the bare int() error does not say where
        # the bad value came from.
        raise ValueError(
            "DYN_KVBM_LEADER_ZMQ_PUB_PORT must be an integer port number, "
            f"got {kvbm_pub_port_str!r}"
        ) from exc
    if kvbm_pub_port < 1:
        raise ValueError(
            f"DYN_KVBM_LEADER_ZMQ_PUB_PORT must be a positive port number, got {kvbm_pub_port}"
        )

    # Use 1000 offset to keep ports close together
    # Example: 56001 -> 57001
    consolidator_port_offset = 1000
    output_port = kvbm_pub_port + consolidator_port_offset

    # Validate the derived port is within valid range
    if output_port > 65535:
        raise ValueError(
            f"Derived consolidator port {output_port} exceeds maximum (65535). "
            f"KVBM port {kvbm_pub_port} is too high. Use a lower base port."
        )

    # Build bind and connect endpoints
    # Consolidator binds to 0.0.0.0 (all interfaces), clients connect to 127.0.0.1
    output_bind_endpoint = f"tcp://0.0.0.0:{output_port}"
    output_connect_endpoint = f"tcp://127.0.0.1:{output_port}"

    logger.info(
        f"Consolidator endpoints: vllm={vllm_endpoint}, "
        f"output_bind={output_bind_endpoint}, output_connect={output_connect_endpoint} "
        f"(derived from KVBM port {kvbm_pub_port})"
    )

    # Return both bind and connect endpoints as a tuple
    # First element is vllm_endpoint (for consolidator to subscribe)
    # Second element is output_bind_endpoint (for consolidator to bind/publish)
    # Third element is output_connect_endpoint (for clients to connect)
    return vllm_endpoint, output_bind_endpoint, output_connect_endpoint
......@@ -14,6 +14,7 @@ pub mod block;
pub mod connector;
pub mod distributed;
pub mod events;
pub mod kv_consolidator;
pub mod layout;
pub mod metrics_kvbm;
pub mod numa_allocator;
......
......@@ -230,6 +230,16 @@ pub struct RegistrationHandle {
}
impl RegistrationHandle {
/// Returns the block size (number of tokens in the block)
pub fn block_size(&self) -> usize {
self.token_block.block_size()
}
/// Returns a reference to the tokens in this block
pub fn tokens(&self) -> &crate::tokens::Tokens {
self.token_block.tokens()
}
fn from_token_block(
token_block: &TokenBlock,
release_manager: Arc<dyn EventReleaseManager>,
......
......@@ -203,6 +203,15 @@ pub struct KvBlockManagerConfig {
/// Optional KVBM-level metrics for tracking offload/onboard operations
#[builder(default)]
pub kvbm_metrics: Option<crate::block_manager::metrics_kvbm::KvbmMetrics>,
/// Optional KV Event Consolidator Configuration
///
/// If provided, KVBM will create a KV Event Consolidator that deduplicates
/// KV cache events from vLLM (G1) and KVBM (G2/G3) before sending to the router.
/// This is used when `--connector kvbm` is enabled with prefix caching.
#[builder(default, setter(strip_option))]
pub consolidator_config:
Option<crate::block_manager::kv_consolidator::KvEventConsolidatorConfig>,
}
impl KvBlockManagerConfig {
......
......@@ -5,6 +5,9 @@ use std::sync::Arc;
use super::block::registry::RegistrationHandle;
use crate::block_manager::kv_consolidator::EventSource;
use crate::block_manager::kv_consolidator::KvEventConsolidator;
/// The [EventManager] is not responsible for managing the history of the blocks, nor what
/// events have been published.
///
......@@ -141,6 +144,148 @@ impl EventReleaseManager for NullEventManager {
fn block_release(&self, _registration_handle: &RegistrationHandle) {}
}
/// Event manager that sends KVBM events to the kv event consolidator
///
/// Store and remove notifications are forwarded through `consolidator_handle`.
pub struct DynamoEventManager {
/// Handle through which store/remove events are sent to the consolidator.
consolidator_handle: Arc<crate::block_manager::kv_consolidator::KvEventConsolidatorHandle>,
/// `Some` only when the manager was built by `new_with_config`, which creates
/// the consolidator itself; `None` when built by `new`.
/// NOTE(review): presumably held only to keep the consolidator alive for the
/// lifetime of this manager — confirm.
#[allow(dead_code)]
_consolidator: Option<Arc<crate::block_manager::kv_consolidator::KvEventConsolidator>>,
}
impl DynamoEventManager {
    /// Wrap an existing consolidator handle in a new manager.
    ///
    /// A manager built this way does not hold a consolidator of its own.
    pub fn new(
        consolidator_handle: Arc<crate::block_manager::kv_consolidator::KvEventConsolidatorHandle>,
    ) -> Arc<Self> {
        let manager = Self {
            consolidator_handle,
            _consolidator: None,
        };
        Arc::new(manager)
    }

    /// Build a manager that creates, starts and keeps its own kv event consolidator.
    ///
    /// # Errors
    /// Propagates any error from constructing or starting the consolidator.
    pub async fn new_with_config(
        config: crate::block_manager::kv_consolidator::KvEventConsolidatorConfig,
    ) -> anyhow::Result<Arc<Self>> {
        let mut consolidator = KvEventConsolidator::new(config)?;
        consolidator.start().await?;

        let consolidator_handle = Arc::new(consolidator.get_handle());
        let manager = Self {
            consolidator_handle,
            _consolidator: Some(Arc::new(consolidator)),
        };
        Ok(Arc::new(manager))
    }

    /// Forward store events for newly registered blocks to the kv event consolidator.
    ///
    /// Does nothing for an empty list. Otherwise the events are sent from a task
    /// spawned on the current Tokio runtime, one block at a time in list order.
    /// If no runtime is available the events are dropped and an error is logged.
    fn publish_store_events(&self, handles: Vec<Arc<RegistrationHandle>>) {
        if handles.is_empty() {
            return;
        }

        tracing::debug!(
            "DynamoEventManager::publish_store_events called with {} blocks",
            handles.len()
        );

        let consolidator = self.consolidator_handle.clone();
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                runtime.spawn(async move {
                    for registration in handles {
                        // Pull the block metadata out of the registration handle.
                        let block_hash = registration.sequence_hash().to_string();
                        let parent_hash =
                            registration.parent_sequence_hash().map(|h| h.to_string());
                        let block_size = registration.block_size();
                        let tokens: Vec<u32> = registration.tokens().iter().copied().collect();

                        tracing::debug!(
                            "DynamoEventManager sending store event to kv event consolidator: block_hash={}, block_size={}, tokens={}",
                            block_hash,
                            block_size,
                            tokens.len()
                        );

                        // Events from this manager are always tagged as coming from KVBM.
                        consolidator
                            .handle_store(
                                block_hash,
                                EventSource::Kvbm,
                                tokens,
                                parent_hash,
                                block_size,
                                None, // lora_id
                                None, // tier
                                None, // data_parallel_rank
                            )
                            .await;
                    }
                });
            }
            Err(_) => {
                tracing::error!(
                    "No Tokio runtime in context; dropping store events for {} blocks",
                    handles.len()
                );
            }
        }
    }

    /// Forward a remove event for a single block to the kv event consolidator.
    ///
    /// The event is sent from a task spawned on the current Tokio runtime; if no
    /// runtime is available the event is dropped and an error is logged.
    fn publish_remove_event(&self, registration_handle: &RegistrationHandle) {
        let block_hash = registration_handle.sequence_hash().to_string();

        tracing::debug!(
            "DynamoEventManager::publish_remove_event called: block_hash={}",
            block_hash
        );

        let consolidator = self.consolidator_handle.clone();
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                runtime.spawn(async move {
                    consolidator
                        .handle_remove(&block_hash, EventSource::Kvbm)
                        .await;
                });
            }
            Err(_) => {
                tracing::error!(
                    "No Tokio runtime in context; dropping remove event for block {}",
                    block_hash
                );
            }
        }
    }
}
impl std::fmt::Debug for DynamoEventManager {
    /// Writes a fixed label; no field contents are included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DynamoEventManager(kv_event_consolidator)")
    }
}
impl EventManager for DynamoEventManager {}
/// Publishing registered blocks forwards them as store events.
impl EventPublisher for DynamoEventManager {
/// Delegates to `publish_store_events` with the given handles.
fn publish(&self, handles: Vec<Arc<RegistrationHandle>>) {
self.publish_store_events(handles);
}
}
/// Releasing a block forwards a remove event.
impl EventReleaseManager for DynamoEventManager {
/// Delegates to `publish_remove_event` for the released registration handle.
fn block_release(&self, registration_handle: &RegistrationHandle) {
self.publish_remove_event(registration_handle);
}
}
#[cfg(test)]
pub mod tests {
use crate::tokens::SequenceHash;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::{
kv_router::{
indexer::RouterEvent,
protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash,
},
KV_EVENT_SUBJECT,
},
tokens::BlockHash,
};
use derive_getters::{Dissolve, Getters};
use dynamo_runtime::traits::events::EventPublisher;
use dynamo_runtime::{
component::{Component, Namespace},
raise, Result,
};
use std::sync::Arc;
use tokio::sync::mpsc;
/// Destination used to publish router events: either a [`Component`] or a
/// [`Namespace`].
pub enum DynamoPublisher {
/// Publish through a component.
Component(Component),
/// Publish through a namespace.
Namespace(Namespace),
}
impl DynamoPublisher {
    /// Publish `event` on the `KV_EVENT_SUBJECT` subject through whichever
    /// target this publisher wraps, returning that target's result.
    pub async fn publish(&self, event: RouterEvent) -> Result<()> {
        let subject = KV_EVENT_SUBJECT;
        match self {
            Self::Component(target) => target.publish(subject, &event).await,
            Self::Namespace(target) => target.publish(subject, &event).await,
        }
    }
}
/// Sending half of the unbounded channel that carries `Event`s to the
/// background publishing task.
struct EventChannel {
/// Sender used to enqueue events.
tx: mpsc::UnboundedSender<Event>,
}
/// Turns block releases into `Event::RemoveSingle` messages on the channel.
impl EventReleaseManager for EventChannel {
// Generalize sequence_hash
/// Enqueues a remove event for `sequence_hash`. A failed send is only logged
/// as a warning; nothing is returned to the caller.
fn block_release(&self, sequence_hash: SequenceHash) {
if self.tx.send(Event::RemoveSingle(sequence_hash)).is_err() {
tracing::warn!("Failed to send remove block event");
}
}
}
/// Event manager that enqueues block store/remove events on a channel drained
/// by a background task (`progress_engine`).
pub struct NatsEventManager {
/// Shared channel sender; also cloned into each registration handle as its
/// release manager.
event_channel: Arc<EventChannel>,
}
impl NatsEventManager {
    // todo - generalize identifier
    /// Create the manager and start its background publishing task.
    ///
    /// An unbounded channel is created; the receiving half, the publisher and
    /// the worker identifier are handed to `progress_engine`, which is spawned
    /// with `tokio::spawn`. The returned manager keeps the sending half.
    pub async fn new(publisher: DynamoPublisher, worker_identifier: u64) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();

        tokio::spawn(progress_engine(NatsEventsManagerState {
            rx,
            publisher,
            worker_identifier,
        }));

        let event_channel = Arc::new(EventChannel { tx });
        Self { event_channel }
    }
}
impl std::fmt::Debug for NatsEventManager {
    /// Writes only the type name; no field contents are included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NatsEventManager")
    }
}
/// Registers blocks by enqueueing store events and handing back registration
/// handles that reference the shared event channel.
impl EventManager for NatsEventManager {
/// Registers a single block.
///
/// Enqueues an `Event::StoreSingle` carrying the block hash, sequence hash and
/// optional parent sequence hash. If the send fails, a warning is logged and
/// `raise!` is invoked (presumably an early `Err` return — confirm against the
/// macro definition). On success the returned handle carries the block's
/// sequence hash and a clone of the event channel as its release manager.
fn register_block(&self, token_block: &TokenBlock) -> Result<RegistrationHandle> {
let event = Event::StoreSingle(RegisterBlockEvent {
block_hash: LocalBlockHash(token_block.block_hash()),
sequence_hash: ExternalSequenceBlockHash(token_block.sequence_hash()),
parent_hash: token_block
.parent_sequence_hash()
.map(ExternalSequenceBlockHash),
});
if self.event_channel.tx.send(event).is_err() {
tracing::warn!("Failed to send store block event");
raise!("Failed to send store block event");
}
Ok(RegistrationHandle {
sequence_hash: token_block.sequence_hash(),
release_manager: Some(self.event_channel.clone()),
})
}
/// Registers several blocks with a single `Event::StoreMultiple`.
///
/// The event lists each block's (local hash, sequence hash) pair; its parent
/// hash is taken from the first block only. One handle per block is built
/// before the event is sent; if the send fails, a warning is logged and
/// `raise!` is invoked.
fn register_blocks(&self, token_blocks: &[TokenBlock]) -> Result<Vec<RegistrationHandle>> {
let event = Event::StoreMultiple(RegisterBlocksEvent {
hashes: token_blocks
.iter()
.map(|block| {
(
LocalBlockHash(block.block_hash()),
ExternalSequenceBlockHash(block.sequence_hash()),
)
})
.collect(),
// Only the first block's parent is recorded for the whole batch.
parent_hash: token_blocks
.first()
.and_then(|block| block.parent_sequence_hash().map(ExternalSequenceBlockHash)),
});
let handles = token_blocks
.iter()
.map(|block| RegistrationHandle {
sequence_hash: block.sequence_hash(),
release_manager: Some(self.event_channel.clone()),
})
.collect();
if self.event_channel.tx.send(event).is_err() {
tracing::warn!("Failed to send store block event");
raise!("Failed to send store block event");
}
Ok(handles)
}
}
/// State handed to `progress_engine`; `Dissolve` lets it be split into a tuple
/// of its fields.
#[derive(Dissolve)]
struct NatsEventsManagerState {
/// Receiving half of the event channel.
rx: mpsc::UnboundedReceiver<Event>,
/// Target used to publish router events.
publisher: DynamoPublisher,
/// Identifier attached to every published router event.
worker_identifier: WorkerIdentifier,
}
/// Background loop that turns queued block events into router events.
///
/// Each event received from the channel is converted into a `KvCacheEvent`
/// (stored or removed), wrapped in a `RouterEvent` tagged with the worker
/// identifier, and published. Publish failures are logged as warnings and do
/// not stop the loop. The event id increases by one for every received event,
/// whether or not publishing succeeded. The loop ends when the channel closes.
async fn progress_engine(state: NatsEventsManagerState) {
    let (mut rx, publisher, worker_identifier) = state.dissolve();
    let mut event_id = 0;

    while let Some(incoming) = rx.recv().await {
        // Build the payload and pick the warning to emit if publishing fails.
        let (data, failure_message) = match incoming {
            Event::StoreSingle(single) => {
                let stored = KvCacheStoreData {
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash: single.sequence_hash,
                        tokens_hash: single.block_hash,
                    }],
                    parent_hash: single.parent_hash,
                };
                (
                    KvCacheEventData::Stored(stored),
                    "Failed to publish store event",
                )
            }
            Event::StoreMultiple(batch) => {
                let stored = KvCacheStoreData {
                    blocks: batch
                        .hashes
                        .iter()
                        .map(|(local_hash, external_hash)| KvCacheStoredBlockData {
                            block_hash: *external_hash,
                            tokens_hash: *local_hash,
                        })
                        .collect(),
                    parent_hash: batch.parent_hash,
                };
                (
                    KvCacheEventData::Stored(stored),
                    "Failed to publish store event",
                )
            }
            Event::RemoveSingle(sequence_hash) => {
                let removed = KvCacheRemoveData {
                    block_hashes: vec![ExternalSequenceBlockHash(sequence_hash)],
                };
                (
                    KvCacheEventData::Removed(removed),
                    "Failed to publish remove event",
                )
            }
        };

        let router_event =
            RouterEvent::new(worker_identifier as i64, KvCacheEvent { event_id, data });
        if publisher.publish(router_event).await.is_err() {
            tracing::warn!("{}", failure_message);
        }

        event_id += 1;
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Configuration for the KV Event Consolidator
use serde::{Deserialize, Serialize};
/// Configuration for the KV Event Consolidator
///
/// Holds the two ZMQ endpoint strings the consolidator works with: the endpoint it
/// subscribes to for vLLM events and the endpoint it publishes consolidated events on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvEventConsolidatorConfig {
/// ZMQ endpoint to subscribe to vLLM events (e.g., "tcp://localhost:5557")
pub vllm_event_endpoint: String,
/// ZMQ endpoint to publish consolidated events (e.g., "tcp://*:5558")
pub consolidated_event_endpoint: String,
}
impl Default for KvEventConsolidatorConfig {
    /// Default endpoints: subscribe on `tcp://localhost:5557`, publish on `tcp://*:5558`.
    fn default() -> Self {
        let vllm_event_endpoint = String::from("tcp://localhost:5557");
        let consolidated_event_endpoint = String::from("tcp://*:5558");
        Self {
            vllm_event_endpoint,
            consolidated_event_endpoint,
        }
    }
}
impl KvEventConsolidatorConfig {
pub fn new(vllm_event_endpoint: String, consolidated_event_endpoint: String) -> Self {
Self {
vllm_event_endpoint,
consolidated_event_endpoint,
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! KV Event Consolidator
//!
//! This module consolidates kv events from multiple sources (vLLM's G1 events
//! and KVBM's G2/G3 events) before publishing them to the router.
pub mod config;
pub mod publisher;
pub mod subscriber;
pub mod tracker;
pub use config::KvEventConsolidatorConfig;
pub use publisher::KvEventConsolidatorPublisher;
pub use tracker::{CacheStatusTracker, EventSource, StorageTier};
use anyhow::Result;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use subscriber::start_simple_zmq_listener;
/// Handle for KVBM to send G2/G3 events directly to the KV Event Consolidator
///
/// The handle only holds an `Arc` to the shared tracker, so cloning it shares the
/// same underlying tracker state.
#[derive(Clone, Debug)]
pub struct KvEventConsolidatorHandle {
/// Shared cache-status tracker that the handle's methods lock for writing.
pub(crate) tracker: Arc<RwLock<CacheStatusTracker>>,
}
impl KvEventConsolidatorHandle {
/// Send a block store event to the KV Event Consolidator
///
/// This is called by KVBM when a block is stored in G2 or G3.
#[allow(clippy::too_many_arguments)]
pub async fn handle_store(
&self,
block_hash: String,
source: EventSource,
token_ids: Vec<u32>,
parent_hash: Option<String>,
block_size: usize,
lora_id: Option<u64>,
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) {
let mut tracker = self.tracker.write().await;
tracker.handle_store(
block_hash,
source,
token_ids,
parent_hash,
block_size,
lora_id.map(|id| id as i32),
tier,
data_parallel_rank,
);
}
/// Send a block remove event to the KV Event Consolidator
///
/// This is called by KVBM when a block is removed from G2 or G3.
pub async fn handle_remove(&self, block_hash: &str, source: EventSource) {
let mut tracker = self.tracker.write().await;
tracker.handle_remove(block_hash, source);
}
/// Clear all blocks from the KV Event Consolidator
///
/// This is called by KVBM when all blocks should be evicted.
pub async fn handle_clear_all(&self) {
let mut tracker = self.tracker.write().await;
tracker.handle_clear_all();
}
}
/// The main KV Event Consolidator that manages the event flow
///
/// Owns the shared tracker plus the two background pieces created by `start`: the
/// publisher and the ZMQ subscriber task.
pub struct KvEventConsolidator {
// Endpoints to subscribe to and publish on.
config: KvEventConsolidatorConfig,
// Tracker shared with the publisher, the subscriber task and any handles.
tracker: Arc<RwLock<CacheStatusTracker>>,
// Join handle of the subscriber task; `None` until `start` runs, taken by `shutdown`.
subscriber_handle: Option<JoinHandle<()>>,
// Token passed to the subscriber task and cancelled in `shutdown`.
cancellation_token: CancellationToken,
// Publisher created by `start`; taken and shut down by `shutdown`.
publisher: Option<KvEventConsolidatorPublisher>,
}
impl KvEventConsolidator {
    /// Create a new KV Event Consolidator
    ///
    /// Nothing is started here: the tracker is empty and no publisher or subscriber
    /// exists until [`start`](Self::start) is called.
    pub fn new(config: KvEventConsolidatorConfig) -> Result<Self> {
        Ok(Self {
            config,
            tracker: Arc::new(RwLock::new(CacheStatusTracker::new())),
            subscriber_handle: None,
            cancellation_token: CancellationToken::new(),
            publisher: None,
        })
    }

    /// Start the KV Event Consolidator
    ///
    /// Creates the publisher, waits 100ms, then starts the subscriber task that
    /// connects to vLLM's event endpoint.
    ///
    /// # Errors
    /// Propagates errors from creating the publisher or starting the listener.
    pub async fn start(&mut self) -> Result<()> {
        let config = &self.config;
        tracing::info!(
            "Starting KV Event Consolidator: subscribe from {}, publish to {}",
            config.vllm_event_endpoint,
            config.consolidated_event_endpoint
        );

        // Publisher goes up first.
        self.publisher = Some(KvEventConsolidatorPublisher::new(
            &config.consolidated_event_endpoint,
            Arc::clone(&self.tracker),
        )?);

        tracing::info!("Waiting for downstream subscribers to connect...");
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Then the subscriber, which connects to vLLM's publisher.
        let listener = start_simple_zmq_listener(
            self.config.vllm_event_endpoint.clone(),
            Arc::clone(&self.tracker),
            self.cancellation_token.clone(),
        )
        .await?;
        self.subscriber_handle = Some(listener);

        tracing::info!("KV Event Consolidator fully started and ready");
        Ok(())
    }

    /// Shutdown the KV Event Consolidator
    ///
    /// Cancels the token, aborts and awaits the subscriber task (if any), then shuts
    /// down the publisher (if any).
    ///
    /// # Errors
    /// Propagates an error returned by the publisher's `shutdown`.
    pub async fn shutdown(&mut self) -> Result<()> {
        tracing::info!("Shutting down KV Event Consolidator");

        // Signal the ZMQ listener to stop.
        self.cancellation_token.cancel();

        // The task is aborted as well; the join result is intentionally discarded.
        if let Some(listener) = self.subscriber_handle.take() {
            listener.abort();
            let _ = listener.await;
        }

        match self.publisher.take() {
            Some(publisher) => publisher.shutdown().await?,
            None => {}
        }

        Ok(())
    }

    /// Get a shared pointer to the cache status tracker (for debugging/metrics)
    pub fn tracker(&self) -> Arc<RwLock<CacheStatusTracker>> {
        Arc::clone(&self.tracker)
    }

    /// Get a handle that KVBM can use to send G2/G3 kv events directly
    pub fn get_handle(&self) -> KvEventConsolidatorHandle {
        let tracker = Arc::clone(&self.tracker);
        KvEventConsolidatorHandle { tracker }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! ZMQ Publisher for KV Events Consolidator
//!
//! Publishes consolidated KV cache events to the router using the same format as vLLM.
use anyhow::{Context, Result};
use bytes::Bytes;
use rmp_serde::Serializer;
use serde::Serialize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use zeromq::{PubSocket, Socket, SocketSend};
use super::tracker::{CacheStatusTracker, ConsolidatedEvent};
/// Event batch structure matching vLLM's format (array_like=True)
/// Format: [timestamp, [events], data_parallel_rank]
///
/// Note: This uses a tuple struct to serialize as an array [ts, events, rank]
/// rather than an object {"ts": ..., "events": ..., "rank": ...} for vLLM compatibility.
///
/// Only `Serialize` is derived: this type is used for outgoing batches.
#[derive(Debug, Serialize)]
struct EventBatch(
f64, // ts: seconds since the Unix epoch, taken when the batch is built
Vec<Event>, // events: the converted events carried by this batch
Option<i32>, // data_parallel_rank
);
/// Event types matching vLLM's format
/// Note: block_hashes are u64 to match vLLM's ExternalBlockHash type
///
/// `#[serde(tag = "type")]` selects serde's internally tagged representation, with the
/// variant name (as renamed below) stored under the `type` key.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum Event {
/// One or more blocks were stored.
#[serde(rename = "BlockStored")]
BlockStored {
// Hashes of the stored blocks.
block_hashes: Vec<u64>,
// Hash of the parent block, if there is one.
parent_block_hash: Option<u64>,
// Token IDs, narrowed to `i32` by `from_consolidated`.
token_ids: Vec<i32>,
// Block size, narrowed to `i32` by `from_consolidated`.
block_size: i32,
// Optional LoRA id, passed through unchanged.
lora_id: Option<i32>,
},
/// One or more blocks were removed.
#[serde(rename = "BlockRemoved")]
BlockRemoved { block_hashes: Vec<u64> },
/// Every block was cleared; carries no fields.
#[serde(rename = "AllBlocksCleared")]
AllBlocksCleared {},
}
impl Event {
    /// Convert a `ConsolidatedEvent` into the vLLM wire `Event`.
    ///
    /// String block hashes are parsed back to `u64` for router compatibility. The
    /// `source` field of the consolidated event is only used for internal logging and
    /// is dropped here.
    ///
    /// Token IDs and the block size are narrowed to `i32`; a value that does not fit
    /// is logged and clamped to `i32::MAX`.
    ///
    /// # Errors
    /// Returns an error when a block hash or parent hash is not a valid `u64`, so a
    /// corrupted event is never sent to the router.
    fn from_consolidated(event: ConsolidatedEvent) -> Result<Self> {
        let converted = match event {
            ConsolidatedEvent::ClearAll => Event::AllBlocksCleared {},

            ConsolidatedEvent::Remove { block_hash, .. } => {
                // Invalid hash -> error rather than a bogus removal.
                let hash = block_hash.parse::<u64>().with_context(|| {
                    format!("Failed to parse block_hash for removal: {}", block_hash)
                })?;
                Event::BlockRemoved {
                    block_hashes: vec![hash],
                }
            }

            ConsolidatedEvent::Store {
                block_hash,
                parent_hash,
                token_ids,
                block_size,
                lora_id,
                .. // `source` is for logging only and is not sent to the router
            } => {
                let hash = block_hash
                    .parse::<u64>()
                    .with_context(|| format!("Failed to parse block_hash: {}", block_hash))?;

                let parent_block_hash = match parent_hash {
                    Some(parent) => Some(
                        parent
                            .parse::<u64>()
                            .with_context(|| format!("Failed to parse parent_hash: {}", parent))?,
                    ),
                    None => None,
                };

                // u32 -> i32 for vLLM compatibility, clamping the (unexpected) overflow.
                let mut vllm_token_ids: Vec<i32> = Vec::with_capacity(token_ids.len());
                for token in token_ids {
                    let narrowed = match i32::try_from(token) {
                        Ok(value) => value,
                        Err(_) => {
                            tracing::warn!(
                                "Token ID {} exceeds i32::MAX, clamping to i32::MAX",
                                token
                            );
                            i32::MAX
                        }
                    };
                    vllm_token_ids.push(narrowed);
                }

                // usize -> i32 for vLLM compatibility, same clamping policy.
                let vllm_block_size = match i32::try_from(block_size) {
                    Ok(value) => value,
                    Err(_) => {
                        tracing::warn!(
                            "Block size {} exceeds i32::MAX, clamping to i32::MAX",
                            block_size
                        );
                        i32::MAX
                    }
                };

                Event::BlockStored {
                    block_hashes: vec![hash],
                    parent_block_hash,
                    token_ids: vllm_token_ids,
                    block_size: vllm_block_size,
                    lora_id,
                }
            }
        };

        Ok(converted)
    }
}
/// ZMQ Publisher for consolidated events
///
/// Constructed by `new`, which also spawns the background task that drains the
/// tracker and publishes batches.
pub struct KvEventConsolidatorPublisher {
// Endpoint string the publisher task binds its PUB socket to.
endpoint: String,
// Tracker whose queued events the publisher task drains.
tracker: Arc<RwLock<CacheStatusTracker>>,
// Counter shared with the publisher task; incremented once per batch it sends.
sequence: Arc<AtomicU64>,
// Join handle of the publisher task; aborted and awaited by `shutdown`.
task_handle: Option<JoinHandle<()>>,
}
impl KvEventConsolidatorPublisher {
    /// Create a new publisher and spawn its background publishing task.
    ///
    /// The task is started with `tokio::spawn`, so this must be called from within a
    /// Tokio runtime. The socket bind happens inside the task, which means a bind
    /// failure is not reported through the returned `Result`; the task panics instead.
    pub fn new(endpoint: &str, tracker: Arc<RwLock<CacheStatusTracker>>) -> Result<Self> {
        let endpoint = endpoint.to_string();
        let sequence = Arc::new(AtomicU64::new(0));

        // The task gets its own copies; the originals are stored in the returned struct.
        let task_handle = tokio::spawn({
            let endpoint = endpoint.clone();
            let tracker = Arc::clone(&tracker);
            let sequence = Arc::clone(&sequence);
            async move {
                if let Err(e) = Self::run_publisher_loop(endpoint, tracker, sequence).await {
                    // Bind failures and other critical errors must not pass silently.
                    // Note: a panic inside a spawned Tokio task terminates that task
                    // (and surfaces as a `JoinError` on its handle); it does not by
                    // itself bring down the whole process.
                    panic!("Publisher task failed: {}", e);
                }
            }
        });

        Ok(Self {
            endpoint,
            tracker,
            sequence,
            task_handle: Some(task_handle),
        })
    }

    /// Stop the publisher task
    ///
    /// Aborts the background task and waits for it to finish. The join result is
    /// discarded, so this currently always returns `Ok(())`.
    pub async fn shutdown(self) -> Result<()> {
        if let Some(handle) = self.task_handle {
            handle.abort();
            let _ = handle.await;
        }
        Ok(())
    }

    /// Main publisher loop
    ///
    /// Binds a ZMQ PUB socket to `endpoint`, then every 50ms drains the tracker's
    /// pending events, converts them to the vLLM wire format, and sends them as a
    /// three-frame message `[topic, sequence, payload]` with a msgpack payload.
    ///
    /// # Errors
    /// Returns an error if the bind fails or a batch cannot be serialized. Conversion,
    /// message-construction and send failures are logged and the loop keeps running.
    async fn run_publisher_loop(
        endpoint: String,
        tracker: Arc<RwLock<CacheStatusTracker>>,
        sequence: Arc<AtomicU64>,
    ) -> Result<()> {
        tracing::info!("Starting consolidated event publisher on {}", endpoint);

        // Create ZMQ PUB socket and bind
        let mut socket = PubSocket::new();
        socket
            .bind(&endpoint)
            .await
            .with_context(|| format!("Failed to bind publisher to {}", endpoint))?;
        tracing::info!("Publisher bound to {}", endpoint);

        // Publish loop - check for events every 50ms
        let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(50));
        loop {
            interval.tick().await;

            // Drain events from tracker; the write guard is dropped at the end of this
            // block so it is not held while serializing or sending.
            let events = {
                let mut tracker_guard = tracker.write().await;
                tracker_guard.drain_events()
            };
            if events.is_empty() {
                continue;
            }
            tracing::debug!(
                "Publishing {} consolidated event(s) to router",
                events.len()
            );

            // Convert to vLLM format, filtering out events with invalid hashes
            let vllm_events: Vec<Event> = events
                .into_iter()
                .filter_map(|event| match Event::from_consolidated(event) {
                    Ok(e) => Some(e),
                    Err(err) => {
                        tracing::error!("Failed to convert consolidated event, skipping: {}", err);
                        None
                    }
                })
                .collect();

            // Skip publishing if all events were invalid
            if vllm_events.is_empty() {
                tracing::warn!("All consolidated events failed validation, skipping publish");
                continue;
            }
            let num_events = vllm_events.len(); // Save length before move

            // Seconds since the Unix epoch. If the system clock reads earlier than the
            // epoch, report the (negative) offset instead of panicking the task.
            let ts = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|elapsed| elapsed.as_secs_f64())
                .unwrap_or_else(|before_epoch| -before_epoch.duration().as_secs_f64());

            let batch = EventBatch(
                ts,          // ts
                vllm_events, // events
                Some(0),     // data_parallel_rank (default)
            );

            // Serialize to msgpack
            let mut payload = Vec::new();
            batch
                .serialize(&mut Serializer::new(&mut payload))
                .context("Failed to serialize event batch")?;

            // Get sequence number (big-endian on the wire)
            let seq = sequence.fetch_add(1, Ordering::SeqCst);
            let seq_bytes = seq.to_be_bytes();

            // Send multipart message: [topic, sequence, payload]
            // Empty topic means all subscribers receive it
            let frames = vec![
                Bytes::from(""),
                Bytes::from(seq_bytes.to_vec()),
                Bytes::from(payload),
            ];
            let msg = match zeromq::ZmqMessage::try_from(frames) {
                Ok(m) => m,
                Err(e) => {
                    tracing::error!("Failed to create multipart ZMQ message: {:?}", e);
                    continue;
                }
            };

            if let Err(e) = socket.send(msg).await {
                tracing::error!("Failed to send consolidated events: {}", e);
            } else {
                tracing::debug!(
                    "Published batch with {} event(s) to router (seq={})",
                    num_events,
                    seq
                );
            }
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Simple ZMQ Subscriber for vLLM KV Events
//!
//! This is a simplified subscriber that deserializes raw vLLM events.
use anyhow::{Context, Result};
use rmp_serde::Deserializer;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use super::tracker::{CacheStatusTracker, StorageTier};
/// Event batch received from vLLM (array format)
/// Format: [timestamp, [events], data_parallel_rank]
///
/// Note: This uses a tuple struct to deserialize from array [ts, events, rank]
/// rather than an object {"ts": ..., "events": ..., "rank": ...} for vLLM compatibility.
///
/// The positional fields are read through the `ts()`, `events()` and
/// `data_parallel_rank()` accessors.
#[derive(Debug, Deserialize)]
struct VllmEventBatch(
f64, // ts: batch timestamp as sent by the producer
Vec<VllmRawEvent>, // events: raw events in this batch
Option<i32>, // data_parallel_rank: optional, forwarded to the tracker per event
);
impl VllmEventBatch {
fn ts(&self) -> f64 {
self.0
}
fn events(&self) -> &Vec<VllmRawEvent> {
&self.1
}
fn data_parallel_rank(&self) -> Option<i32> {
self.2
}
}
/// Block hash can be either an integer or a string (bytes hex-encoded)
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order, so a value
/// that deserializes as `u64` becomes `Int`, otherwise a string becomes `Str`.
///
/// NOTE(review): `Event::from_consolidated` in the publisher parses hashes with
/// `parse::<u64>()`, so a `Str` hash that is not a decimal integer would be rejected
/// there - confirm which string format the producer actually sends.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum BlockHash {
// Integer hash.
Int(u64),
// String hash, kept verbatim.
Str(String),
}
impl std::fmt::Display for BlockHash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BlockHash::Int(n) => write!(f, "{}", n),
BlockHash::Str(s) => write!(f, "{}", s),
}
}
}
/// Raw vLLM event format (preserves all data including token_ids)
///
/// Internally tagged via `#[serde(tag = "type")]`: the `type` key selects the variant.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
enum VllmRawEvent {
/// One or more blocks were stored.
#[serde(rename = "BlockStored")]
BlockStored {
// External hashes, one per stored block.
block_hashes: Vec<BlockHash>,
// External hash of the parent of the first block, if any.
parent_block_hash: Option<BlockHash>,
// Token IDs for all blocks, concatenated; `process_event` splits them by `block_size`.
token_ids: Vec<i32>,
// Tokens per block; `process_event` rejects non-positive values.
block_size: i32,
// Optional LoRA id, forwarded to the tracker unchanged.
lora_id: Option<i32>,
// Optional storage medium string; absent in the payload -> `None`.
#[serde(default)]
medium: Option<String>,
},
/// One or more blocks were removed.
#[serde(rename = "BlockRemoved")]
BlockRemoved {
// External hashes of the removed blocks.
block_hashes: Vec<BlockHash>,
// Optional storage medium string; absent in the payload -> `None`.
#[serde(default)]
medium: Option<String>,
},
/// Every block was cleared; carries no fields.
#[serde(rename = "AllBlocksCleared")]
AllBlocksCleared {},
}
/// Start ZMQ listener and process events into tracker
pub async fn start_simple_zmq_listener(
endpoint: String,
tracker: Arc<RwLock<CacheStatusTracker>>,
cancellation_token: CancellationToken,
) -> Result<JoinHandle<()>> {
let handle = tokio::spawn(async move {
if let Err(e) = run_listener_loop(endpoint, tracker, cancellation_token).await {
tracing::error!("ZMQ listener task failed: {}", e);
}
});
Ok(handle)
}
/// Connects a ZMQ SUB socket to `endpoint` and feeds received event batches into the tracker.
///
/// Subscribes to every topic (empty prefix). Each message must have 2 frames
/// (`[topic, payload]`) or 3 frames (`[topic, sequence, payload]`); the payload is a
/// msgpack-encoded `VllmEventBatch`. Malformed messages are logged and skipped, and a
/// receive error is logged followed by a 100ms pause. The loop ends when
/// `cancellation_token` is cancelled; cancellation is polled first (`biased`).
///
/// # Errors
/// Returns an error if connecting or subscribing fails.
async fn run_listener_loop(
    endpoint: String,
    tracker: Arc<RwLock<CacheStatusTracker>>,
    cancellation_token: CancellationToken,
) -> Result<()> {
    tracing::info!(
        "KV event consolidator ZMQ listener connecting to {}",
        endpoint
    );

    let mut socket = SubSocket::new();
    socket
        .connect(&endpoint)
        .await
        .context("Failed to connect to ZMQ endpoint")?;
    socket
        .subscribe("")
        .await
        .context("Failed to subscribe to ZMQ topics")?;

    tracing::info!(
        "KV event consolidator ZMQ listener successfully connected to {}",
        endpoint
    );

    loop {
        tokio::select! {
            biased;

            _ = cancellation_token.cancelled() => {
                tracing::debug!("ZMQ listener received cancellation signal");
                break;
            }

            msg_result = socket.recv() => {
                let msg = match msg_result {
                    Ok(received) => received,
                    Err(recv_err) => {
                        tracing::warn!("Error receiving ZMQ message: {:?}", recv_err);
                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                        continue;
                    }
                };

                // Supported layouts: [topic, payload] and [topic, sequence, payload];
                // in both the payload is the last frame.
                let frames: Vec<Vec<u8>> = msg
                    .into_vec()
                    .into_iter()
                    .map(|frame| frame.to_vec())
                    .collect();
                let payload = match frames.as_slice() {
                    [_, payload] | [_, _, payload] => payload,
                    other => {
                        tracing::warn!("Unexpected frame count: {} (expected 2 or 3)", other.len());
                        continue;
                    }
                };

                // Decode the msgpack event batch.
                let batch = match VllmEventBatch::deserialize(&mut Deserializer::new(&payload[..])) {
                    Ok(decoded) => decoded,
                    Err(e) => {
                        tracing::warn!("Failed to deserialize event batch: {}", e);
                        continue;
                    }
                };

                let dp_rank = batch.data_parallel_rank();
                tracing::debug!(
                    "Consolidator received event batch with {} events (ts={:.2}, dp_rank={:?})",
                    batch.events().len(),
                    batch.ts(),
                    dp_rank
                );

                // Apply the whole batch under a single write lock.
                let mut guard = tracker.write().await;
                for event in batch.events().iter().cloned() {
                    process_event(&mut guard, event, dp_rank);
                }
            }
        }
    }

    Ok(())
}
/// Applies one raw vLLM event to the tracker.
///
/// * `BlockStored` - validates the event, splits `token_ids` into per-block chunks of
///   `block_size`, and stores each block. Within a batch the blocks are chained: the
///   parent of each block is the previous block in the batch, and the parent of the
///   first block is `parent_block_hash`.
/// * `BlockRemoved` - removes each listed block for the vLLM source.
/// * `AllBlocksCleared` - clears the tracker.
///
/// A `BlockStored` event is skipped entirely (with a warning) when `block_size` is
/// not positive, when any token ID is negative, or when the number of token chunks
/// does not match the number of block hashes. A negative token ID invalidates the
/// whole event: dropping only that token would shift every following token into the
/// wrong block, while the chunk count could still match by coincidence.
fn process_event(
    tracker: &mut CacheStatusTracker,
    event: VllmRawEvent,
    data_parallel_rank: Option<i32>,
) {
    match event {
        VllmRawEvent::BlockStored {
            block_hashes,
            parent_block_hash,
            token_ids,
            block_size,
            lora_id,
            medium,
        } => {
            // Missing or unrecognized medium falls back to the device tier.
            let storage_tier = medium
                .as_deref()
                .and_then(StorageTier::from_vllm_medium)
                .unwrap_or(StorageTier::Device);
            tracing::debug!(
                "Processing BlockStored: {} blocks, tier={:?}, tokens={}, block_size={}, parent={:?}, dp_rank={:?}",
                block_hashes.len(),
                storage_tier,
                token_ids.len(),
                block_size,
                parent_block_hash,
                data_parallel_rank
            );

            // `chunks()` panics on a chunk size of 0, so block_size must be strictly positive.
            let block_size_usize = match usize::try_from(block_size) {
                Ok(size) if size > 0 => size,
                _ => {
                    tracing::warn!(
                        "Invalid block_size {} (must be positive), skipping event to avoid chunks() panic",
                        block_size
                    );
                    return;
                }
            };

            // Convert token IDs from i32 to u32; any invalid ID rejects the whole event.
            let mut token_ids_u32: Vec<u32> = Vec::with_capacity(token_ids.len());
            for token in token_ids {
                match u32::try_from(token) {
                    Ok(id) => token_ids_u32.push(id),
                    Err(_) => {
                        tracing::warn!(
                            "Invalid token ID {}, skipping event to keep block contents aligned",
                            token
                        );
                        return;
                    }
                }
            }

            let chunk_count = token_ids_u32.chunks(block_size_usize).len();
            if chunk_count != block_hashes.len() {
                tracing::warn!(
                    "Token chunks ({}) don't match block hashes ({}), skipping event",
                    chunk_count,
                    block_hashes.len()
                );
                return;
            }

            // Process each block with its corresponding token chunk.
            // For batches, chain the blocks: each block's parent is the previous block in the batch.
            let mut current_parent = parent_block_hash.as_ref().map(|h| h.to_string());
            let token_chunks = token_ids_u32.chunks(block_size_usize);
            for (block_hash, block_tokens) in block_hashes.iter().zip(token_chunks) {
                let external_hash = block_hash.to_string();
                tracker.handle_store(
                    external_hash.clone(),
                    crate::block_manager::kv_consolidator::EventSource::Vllm,
                    block_tokens.to_vec(),
                    current_parent.take(),
                    block_size_usize,
                    lora_id,
                    Some(storage_tier),
                    data_parallel_rank,
                );
                // Next block's parent is this block
                current_parent = Some(external_hash);
            }
        }
        VllmRawEvent::BlockRemoved {
            block_hashes,
            medium,
        } => {
            // The tier is only used for the log line below.
            let storage_tier = medium
                .as_deref()
                .and_then(StorageTier::from_vllm_medium)
                .unwrap_or(StorageTier::Device);
            tracing::debug!(
                "Processing BlockRemoved: {} blocks, tier={:?}",
                block_hashes.len(),
                storage_tier
            );
            for block_hash in block_hashes {
                tracker.handle_remove(
                    &block_hash.to_string(),
                    crate::block_manager::kv_consolidator::EventSource::Vllm,
                );
            }
        }
        VllmRawEvent::AllBlocksCleared {} => {
            tracing::debug!("Processing AllBlocksCleared");
            tracker.handle_clear_all();
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Cache Status Tracker
//!
//! Maintains the state of KV cache blocks across different event sources (vLLM and KVBM)
//! and determines when to emit STORE/REMOVE events.
//!
//! - Tracks by EVENT SOURCE (vLLM vs KVBM) instead of storage tier
//! - vLLM source: G1 (GPU) events from vLLM worker
//! - KVBM source: G2/G3 (host pinned/disk) events from KVBM
//! - Deduplication: Uses SequenceHash as the key
//! - Always computes sequence hash using KVBM's xxHash3 method, regardless of source
//! - SequenceHash = first block: Hash(tokens), subsequent: Hash([parent_seq_hash, block_hash])
//! - Emit Store: Only when a block is first stored from ANY source
//! - Emit Remove: Only when a block is removed from ALL sources
use std::collections::{HashMap, HashSet};
/// LocalBlockHash type (content hash from tokens only)
///
/// Produced by `compute_local_block_hash`.
type LocalBlockHash = u64;
/// SequenceHash type (position-aware hash, includes parent context)
///
/// Produced by `compute_sequence_hash`; used as the key of the tracker's block map.
type SequenceHash = u64;
/// Seed for xxHash3 computation (must match the indexer's seed)
///
/// Used for both the local block hash and the sequence hash.
const XXH3_SEED: u64 = 1337;
/// Compute a LocalBlockHash from token IDs (content only)
///
/// The tokens are laid out as consecutive little-endian `u32` values and hashed with
/// seeded xxHash3 (64-bit).
fn compute_local_block_hash(token_ids: &[u32]) -> LocalBlockHash {
    let mut bytes: Vec<u8> = Vec::with_capacity(token_ids.len() * std::mem::size_of::<u32>());
    for token in token_ids {
        bytes.extend_from_slice(&token.to_le_bytes());
    }
    xxhash_rust::xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED)
}
/// Compute a SequenceHash from parent sequence hash and current block hash
/// This mirrors the indexer's sequence hash computation for consistent tracking
///
/// For the first block (no parent): sequence_hash = block_hash
/// For subsequent blocks: sequence_hash = hash([parent_sequence_hash, current_block_hash]),
/// where both values are laid out as little-endian `u64` and hashed with seeded xxHash3.
fn compute_sequence_hash(
    parent_sequence_hash: Option<SequenceHash>,
    current_block_hash: LocalBlockHash,
) -> SequenceHash {
    // First block: the sequence hash is the block hash itself.
    let parent_hash = match parent_sequence_hash {
        Some(hash) => hash,
        None => return current_block_hash,
    };

    // Subsequent block: hash the parent's sequence hash followed by this block's hash.
    let mut bytes: Vec<u8> = Vec::with_capacity(2 * std::mem::size_of::<u64>());
    bytes.extend_from_slice(&parent_hash.to_le_bytes());
    bytes.extend_from_slice(&current_block_hash.to_le_bytes());
    xxhash_rust::xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED)
}
/// Event source for KV cache events
///
/// Derives `Hash`/`Eq` so it can be stored in the per-block `HashSet` of sources.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventSource {
/// Events from vLLM worker (G1/GPU)
Vllm,
/// Events from KVBM
Kvbm,
}
impl std::str::FromStr for EventSource {
    type Err = String;

    /// Parses an event source name.
    ///
    /// Accepts exactly `"vllm"`, `"VLLM"` or `"GPU"` for [`EventSource::Vllm`] and
    /// `"kvbm"` or `"KVBM"` for [`EventSource::Kvbm`]; any other input is an error
    /// carrying a descriptive message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "vllm" | "VLLM" | "GPU") {
            return Ok(EventSource::Vllm);
        }
        if matches!(s, "kvbm" | "KVBM") {
            return Ok(EventSource::Kvbm);
        }
        Err(format!("Unknown event source: {}", s))
    }
}
impl EventSource {
/// Convert to string representation
pub fn to_str(&self) -> &'static str {
match self {
EventSource::Vllm => "vllm",
EventSource::Kvbm => "kvbm",
}
}
}
/// Storage tier information (for metadata/debugging only)
///
/// Note: This does NOT determine the event source!
/// The event source (vLLM vs KVBM) is determined by WHERE the event originates:
/// - Events from vLLM's ZMQ subscriber → EventSource::Vllm
/// - Events from KVBM's DynamoEventManager → EventSource::Kvbm
///
/// The vLLM medium strings corresponding to each tier are defined by
/// `from_vllm_medium` / `to_vllm_medium`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StorageTier {
Device, // GPU / G1 (vLLM medium "GPU")
HostPinned, // CPU / G2 (vLLM medium "CPU_TIER1")
Disk, // Disk / G3 (vLLM medium "CPU_TIER2")
}
impl StorageTier {
    /// Parse from vLLM's medium string (e.g., "GPU", "CPU_TIER1", "CPU_TIER2")
    ///
    /// Matching is exact and case-sensitive; anything else yields `None`.
    pub fn from_vllm_medium(s: &str) -> Option<Self> {
        const MEDIA: [(&str, StorageTier); 3] = [
            ("GPU", StorageTier::Device),
            ("CPU_TIER1", StorageTier::HostPinned),
            ("CPU_TIER2", StorageTier::Disk),
        ];
        MEDIA
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, tier)| tier)
    }

    /// Convert to vLLM's medium string
    pub fn to_vllm_medium(&self) -> &'static str {
        match *self {
            Self::Device => "GPU",
            Self::HostPinned => "CPU_TIER1",
            Self::Disk => "CPU_TIER2",
        }
    }

    /// Convert to string representation (`"device"`, `"host_pinned"` or `"disk"`)
    pub fn to_str(&self) -> &'static str {
        match *self {
            Self::Device => "device",
            Self::HostPinned => "host_pinned",
            Self::Disk => "disk",
        }
    }
}
/// Legacy type alias for backward compatibility
///
/// Identical to [`StorageTier`]; using it triggers a deprecation warning.
#[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier;
/// Minimal metadata for tracking which event sources have a block
/// All other metadata (tokens, parent, etc.) is stored in the ConsolidatedEvent when queued
///
/// An empty `sources` set means no source holds the block any more
/// (see `exists_in_any_source`).
#[derive(Debug, Clone)]
pub struct BlockMetadata {
/// Event sources where this block exists (vLLM and/or KVBM)
pub sources: HashSet<EventSource>,
/// The first external block hash seen for this token sequence (for output events)
/// Different sources may have different external hashes, but they all represent the same token content
pub first_block_hash: String,
}
impl BlockMetadata {
    /// Creates metadata for a block first seen from `source` under `block_hash`.
    pub fn new(source: EventSource, block_hash: String) -> Self {
        Self {
            sources: HashSet::from([source]),
            first_block_hash: block_hash,
        }
    }

    /// Check if this block exists in any source
    pub fn exists_in_any_source(&self) -> bool {
        let no_sources_left = self.sources.is_empty();
        !no_sources_left
    }

    /// Add a source to this block
    /// Returns true if this is a new source (wasn't already present)
    pub fn add_source(&mut self, source: EventSource) -> bool {
        let newly_added = self.sources.insert(source);
        newly_added
    }

    /// Remove a source from this block
    /// Returns true if the source was present and removed
    pub fn remove_source(&mut self, source: EventSource) -> bool {
        let was_present = self.sources.remove(&source);
        was_present
    }
}
/// Event to be published to the router
///
/// Queued by the tracker's `handle_*` methods and handed out by `drain_events`.
#[derive(Debug, Clone)]
pub enum ConsolidatedEvent {
/// Block stored (first time across all sources)
Store {
// External hash supplied with the store that first introduced the block.
block_hash: String,
// External hash of the parent block as supplied by the caller, if any.
parent_hash: Option<String>,
// Token IDs contained in the block.
token_ids: Vec<u32>,
// Block size as supplied by the caller.
block_size: usize,
// Optional LoRA id as supplied by the caller.
lora_id: Option<i32>,
source: String, // The source where it was first stored (vllm or kvbm)
},
/// Block removed (removed from all sources)
Remove {
// The block's `first_block_hash`, i.e. the hash its Store event was queued with.
block_hash: String,
source: String, // The source where it was last removed
},
/// All blocks cleared
ClearAll,
}
/// Cache Status Tracker
///
/// Deduplication logic:
/// - Uses SequenceHash (computed from tokens + parent) as the key for deduplication
/// - SequenceHash is position-aware: same tokens at different positions = different keys
/// - Always uses KVBM's xxHash3 hashing function, regardless of source
/// - This allows vLLM and KVBM blocks at the same position to be deduplicated
/// - Emit Store: Only when a block is first stored from ANY source
/// - Emit Remove: Only when a block is removed from ALL sources
///
/// The tracker itself does no locking; callers in this module share it behind an
/// `Arc<RwLock<..>>`.
#[derive(Debug)]
pub struct CacheStatusTracker {
/// Map of SequenceHash -> BlockMetadata (tracking which sources have this block)
/// The key is position-aware: includes parent context
blocks: HashMap<SequenceHash, BlockMetadata>,
/// Reverse mapping: external_block_hash -> SequenceHash (that we computed)
/// Needed because remove events only provide external hash, not token IDs
/// Maps each source's external hash to our computed sequence hash
hash_mapping: HashMap<String, SequenceHash>,
/// Queue of events to be published
/// (appended to by the `handle_*` methods, emptied by `drain_events`)
event_queue: Vec<ConsolidatedEvent>,
}
impl Default for CacheStatusTracker {
fn default() -> Self {
Self::new()
}
}
impl CacheStatusTracker {
/// Creates an empty tracker: no tracked blocks, no hash mappings, no queued events.
pub fn new() -> Self {
    let blocks = HashMap::new();
    let hash_mapping = HashMap::new();
    let event_queue = Vec::new();
    Self {
        blocks,
        hash_mapping,
        event_queue,
    }
}
/// Handle a STORE event
///
/// Returns true if a consolidated STORE event should be published.
/// Only publishes when a block is stored for the FIRST TIME from ANY source.
///
/// The block is keyed by a sequence hash computed from its tokens and, when
/// `parent_hash` is already known to the tracker, the parent's sequence hash. A
/// `parent_hash` that is not in the hash mapping is treated like no parent at all.
/// In every case the external `block_hash` is recorded in the hash mapping so that a
/// later remove event can find the block.
///
/// # Arguments
/// * `block_hash` - The external block hash (from vLLM or KVBM)
/// * `source` - The event source (vLLM or KVBM) that stored this block
/// * `token_ids` - The token IDs in this block (for content-based deduplication)
/// * `parent_hash` - External hash of the parent block, if any
/// * `block_size` - Block size, copied into the queued STORE event
/// * `lora_id` - Optional LoRA id, copied into the queued STORE event
/// * `tier` - Optional storage tier information (for metadata/debugging)
/// * `data_parallel_rank` - Optional data-parallel rank (only logged)
#[allow(clippy::too_many_arguments)]
pub fn handle_store(
    &mut self,
    block_hash: String,
    source: EventSource,
    token_ids: Vec<u32>,
    parent_hash: Option<String>,
    block_size: usize,
    lora_id: Option<i32>,
    tier: Option<StorageTier>,
    data_parallel_rank: Option<i32>,
) -> bool {
    // Log-friendly prefix of a hash: its first 16 bytes, or the whole string when it is
    // shorter than that or when byte 16 is not a character boundary. Plain byte slicing
    // (`&s[..16]`) would panic in the latter case.
    fn hash_preview(hash: &str) -> &str {
        hash.get(..16).unwrap_or(hash)
    }

    // Compute LocalBlockHash from token IDs (content only)
    let local_block_hash = compute_local_block_hash(&token_ids);

    // Resolve parent sequence hash from parent's external hash (if provided)
    let parent_sequence_hash = parent_hash
        .as_ref()
        .and_then(|ph| self.hash_mapping.get(ph).copied());

    // Compute SequenceHash using KVBM's hashing method (position-aware)
    // This ensures consistent deduplication regardless of source
    let sequence_hash = compute_sequence_hash(parent_sequence_hash, local_block_hash);
    tracing::debug!(
        "Computing sequence_hash for block: local_block_hash={}, parent_seq_hash={:?}, sequence_hash={}",
        local_block_hash,
        parent_sequence_hash,
        sequence_hash
    );

    if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
        // Block already exists from another source (deduplication!), just add the new source
        let is_new_source = metadata.add_source(source);

        // Add this external hash to the mapping so remove events from this source can find the block
        // Multiple external hashes (from different sources) can map to the same SequenceHash
        self.hash_mapping.insert(block_hash.clone(), sequence_hash);

        if is_new_source {
            tracing::debug!(
                "DEDUP: Block {} (seq_hash={}) added to source {:?} (already exists in {} source(s), {} tokens, external_hash={})\n Token IDs: {:?}",
                hash_preview(&metadata.first_block_hash),
                sequence_hash,
                source,
                metadata.sources.len(),
                token_ids.len(),
                hash_preview(&block_hash),
                &token_ids
            );
        } else {
            tracing::debug!(
                "Block {} (seq_hash={}) already in source {:?}, external_hash={}\n Token IDs: {:?}",
                hash_preview(&metadata.first_block_hash),
                sequence_hash,
                source,
                hash_preview(&block_hash),
                &token_ids
            );
        }

        // Don't publish a new STORE event (block already exists)
        false
    } else {
        // First time seeing this block from any source - create metadata and queue STORE event
        let metadata = BlockMetadata::new(source, block_hash.clone());
        tracing::debug!(
            "New block {} (seq_hash={}) stored in source {:?} (tier={:?}): {} tokens, block_size={}, parent={}, lora={:?}, dp_rank={:?}\n Token IDs: {:?}",
            hash_preview(&block_hash),
            sequence_hash,
            source,
            tier,
            token_ids.len(),
            block_size,
            parent_hash.as_deref().map(hash_preview).unwrap_or("none"),
            lora_id,
            data_parallel_rank,
            &token_ids
        );
        self.blocks.insert(sequence_hash, metadata);

        // Add to hash mapping so remove events can find the block by external hash
        self.hash_mapping.insert(block_hash.clone(), sequence_hash);

        // Queue a STORE event with full metadata
        self.event_queue.push(ConsolidatedEvent::Store {
            block_hash: block_hash.clone(),
            parent_hash,
            token_ids,
            block_size,
            lora_id,
            source: source.to_str().to_string(),
        });
        tracing::debug!(
            "Block {} (seq_hash={}) stored in first source {:?}, will publish STORE event (total tracked: {}, hash_mapping: {})",
            block_hash,
            sequence_hash,
            source,
            self.blocks.len(),
            self.hash_mapping.len()
        );
        true
    }
}
/// Handle a REMOVE event
///
/// Returns true if a consolidated REMOVE event should be published.
/// Only publishes when a block is removed from ALL sources.
///
/// Returns false (after logging a warning) when the external hash is unknown, when
/// the block is not tracked, or when `source` was not recorded for the block. The
/// queued REMOVE event carries the block's first external hash, which is the hash its
/// STORE event was queued with.
///
/// # Arguments
/// * `block_hash` - The external block hash to remove
/// * `source` - The event source (vLLM or KVBM) that removed this block
pub fn handle_remove(&mut self, block_hash: &str, source: EventSource) -> bool {
    // Log-friendly prefix of a hash: its first 16 bytes, or the whole string when it is
    // shorter than that or when byte 16 is not a character boundary. Plain byte slicing
    // (`&s[..16]`) would panic in the latter case.
    fn hash_preview(hash: &str) -> &str {
        hash.get(..16).unwrap_or(hash)
    }

    // Look up the SequenceHash from the external block hash
    let sequence_hash = match self.hash_mapping.get(block_hash) {
        Some(&hash) => hash,
        None => {
            tracing::warn!(
                "Attempted to remove unknown block {} from source {:?} (not in hash_mapping)",
                block_hash,
                source
            );
            return false;
        }
    };

    if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
        // Remove the source
        let was_removed = metadata.remove_source(source);
        if !was_removed {
            tracing::warn!(
                "Attempted to remove source {:?} from block {} but it wasn't present",
                source,
                block_hash
            );
            return false;
        }

        // Remove this external hash immediately when the source removes it
        // This keeps hash_mapping clean
        // Each external hash belongs to exactly one source, so when that source
        // removes the block, we can safely remove the hash_mapping entry
        self.hash_mapping.remove(block_hash);
        tracing::debug!(
            "Removed hash_mapping entry for {} (hash_mapping size: {})",
            block_hash,
            self.hash_mapping.len()
        );

        // Check if this was the last source
        if !metadata.exists_in_any_source() {
            // Block is gone from all sources - remove from tracker and publish REMOVE
            let first_block_hash = metadata.first_block_hash.clone();
            self.blocks.remove(&sequence_hash);

            // Double-check: clean up any stray hash mappings (should be empty by now)
            // This is a safety check
            // NOTE(review): this scans the entire hash_mapping on every final removal.
            let stray_count_before = self.hash_mapping.len();
            self.hash_mapping
                .retain(|_ext_hash, &mut seq_hash| seq_hash != sequence_hash);
            let stray_count = stray_count_before - self.hash_mapping.len();
            if stray_count > 0 {
                tracing::warn!(
                    "Found {} stray hash_mapping entries for seq_hash={} after all sources removed - cleaned up (hash_mapping size now: {})",
                    stray_count,
                    sequence_hash,
                    self.hash_mapping.len()
                );
            }

            self.event_queue.push(ConsolidatedEvent::Remove {
                block_hash: first_block_hash.clone(),
                source: source.to_str().to_string(),
            });
            tracing::debug!(
                "Block {} (seq_hash={}) removed from last source {:?}, will publish REMOVE event (total tracked: {}, hash_mapping: {})",
                first_block_hash,
                sequence_hash,
                source,
                self.blocks.len(),
                self.hash_mapping.len()
            );
            true
        } else {
            // Block still exists in other sources
            tracing::debug!(
                "Block {} (seq_hash={}) removed from source {:?}, still in {} source(s): {:?} (hash_mapping: {})",
                hash_preview(&metadata.first_block_hash),
                sequence_hash,
                source,
                metadata.sources.len(),
                metadata.sources,
                self.hash_mapping.len()
            );
            false
        }
    } else {
        tracing::warn!(
            "Attempted to remove block {} from source {:?} but block not tracked",
            hash_preview(block_hash),
            source
        );
        false
    }
}
/// Handle a CLEAR_ALL event
///
/// Forgets every tracked block and every external-hash mapping, then queues a
/// `ConsolidatedEvent::ClearAll` for publishing.
pub fn handle_clear_all(&mut self) {
    tracing::debug!("Clearing all {} blocks from tracker", self.blocks.len());

    // Reset both lookup tables before announcing the clear.
    self.hash_mapping.clear();
    self.blocks.clear();

    self.event_queue.push(ConsolidatedEvent::ClearAll);
}
/// Drain all pending events to be published
///
/// Takes ownership of the queued events and leaves an empty queue behind.
/// A debug line is logged only when at least one event is returned.
pub fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
    let pending = std::mem::take(&mut self.event_queue);
    match pending.len() {
        0 => {}
        count => tracing::debug!("Draining {} pending kv event(s) for publishing", count),
    }
    pending
}
/// Number of blocks the tracker currently holds.
pub fn num_blocks(&self) -> usize {
    let tracked = &self.blocks;
    tracked.len()
}
/// Get sources for a specific block by external block hash
///
/// Returns `None` when the external hash has no `hash_mapping` entry or the
/// mapped block is not tracked.
pub fn get_block_sources(&self, external_block_hash: &str) -> Option<&HashSet<EventSource>> {
    // external hash -> mapped hash -> block metadata -> its source set
    self.hash_mapping
        .get(external_block_hash)
        .and_then(|mapped_hash| self.blocks.get(mapped_hash))
        .map(|metadata| &metadata.sources)
}
/// Legacy method for backwards compatibility
#[deprecated(note = "Use get_block_sources instead")]
pub fn get_block_tiers(&self, block_hash: &str) -> Option<&HashSet<EventSource>> {
self.get_block_sources(block_hash)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The first STORE for a block is published and the block becomes tracked.
    #[test]
    fn test_first_store_publishes() {
        let mut tracker = CacheStatusTracker::new();
        let should_publish = tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None, // data_parallel_rank
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 1);
        assert_eq!(tracker.drain_events().len(), 1);
    }

    /// Storing the same block twice from the same source publishes only once.
    #[test]
    fn test_duplicate_store_no_publish() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events(); // Clear first event
        let should_publish = tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(!should_publish);
        assert_eq!(tracker.drain_events().len(), 0);
    }

    /// A second source storing the same tokens is deduplicated onto the same block.
    #[test]
    fn test_multi_source_store() {
        let mut tracker = CacheStatusTracker::new();
        // First store from vLLM
        tracker.handle_store(
            "vllm_hash1".to_string(), // vLLM's external hash
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        // Second store from KVBM - should not publish (same tokens, different external hash)
        let should_publish = tracker.handle_store(
            "kvbm_hash1".to_string(), // KVBM's external hash (different from vLLM)
            EventSource::Kvbm,
            vec![1, 2, 3], // Same tokens!
            None,
            3,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        assert!(!should_publish);
        // Deliberately exercises the deprecated alias.
        #[allow(deprecated)]
        let sources = tracker.get_block_tiers("vllm_hash1").unwrap();
        assert_eq!(sources.len(), 2); // vllm and kvbm
    }

    /// Removing a block from its only source publishes a REMOVE event.
    #[test]
    fn test_remove_from_single_source_publishes() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_remove("hash1", EventSource::Vllm);
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 0);
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(events[0], ConsolidatedEvent::Remove { .. }));
    }

    /// A REMOVE is published only once the last source lets go of the block.
    #[test]
    fn test_remove_from_multi_source_no_publish() {
        let mut tracker = CacheStatusTracker::new();
        // Store from vLLM - first STORE event published
        tracker.handle_store(
            "vllm_hash1".to_string(), // vLLM's external hash
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        // Store from KVBM - no event, block already exists (same tokens, different external hash)
        tracker.handle_store(
            "kvbm_hash1".to_string(), // KVBM's external hash (different from vLLM)
            EventSource::Kvbm,
            vec![1, 2, 3], // Same tokens!
            None,
            3,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        tracker.drain_events();
        // Remove from vLLM - should not publish (still in KVBM)
        let should_publish = tracker.handle_remove("vllm_hash1", EventSource::Vllm);
        assert!(!should_publish);
        assert_eq!(tracker.num_blocks(), 1);
        assert_eq!(tracker.drain_events().len(), 0);
        // Remove from KVBM (last source) - should publish REMOVE event
        let should_publish = tracker.handle_remove("kvbm_hash1", EventSource::Kvbm);
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 0);
        // Exactly one REMOVE event is queued for the final removal
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], ConsolidatedEvent::Remove { .. }));
    }

    /// A block without a parent is tracked and published.
    #[test]
    fn test_sequence_hash_first_block() {
        let mut tracker = CacheStatusTracker::new();
        // First block (no parent)
        let should_publish = tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None, // No parent
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 1);
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
    }

    /// A child block that references a stored parent is tracked separately.
    #[test]
    fn test_sequence_hash_with_parent() {
        let mut tracker = CacheStatusTracker::new();
        // First block
        tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        // Second block with parent
        let should_publish = tracker.handle_store(
            "block2".to_string(),
            EventSource::Vllm,
            vec![5, 6, 7, 8],
            Some("block1".to_string()), // Has parent
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 2);
    }

    /// Identical tokens at a different position (different parent) are distinct blocks.
    #[test]
    fn test_same_tokens_different_position_different_blocks() {
        let mut tracker = CacheStatusTracker::new();
        // First occurrence: tokens [1,2,3,4] at position 0 (no parent)
        tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        // Second occurrence: SAME tokens [1,2,3,4] but at position 1 (with parent)
        // This should be treated as a DIFFERENT block due to sequence hash
        let should_publish = tracker.handle_store(
            "block2".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],           // Same tokens!
            Some("block1".to_string()), // But different parent
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        // Should publish because sequence hash differs (different position)
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 2);
    }

    /// CLEAR_ALL drops every block and mapping and queues a ClearAll event.
    #[test]
    fn test_clear_all() {
        let mut tracker = CacheStatusTracker::new();
        // Add multiple blocks
        tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.handle_store(
            "block2".to_string(),
            EventSource::Kvbm,
            vec![5, 6, 7, 8],
            None,
            4,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        assert_eq!(tracker.num_blocks(), 2);
        // Clear all
        tracker.handle_clear_all();
        assert_eq!(tracker.num_blocks(), 0);
        // The most recently queued event is the ClearAll
        let events = tracker.drain_events();
        assert!(matches!(events.last(), Some(ConsolidatedEvent::ClearAll)));
        // Verify hash_mapping is also cleared
        let should_publish = tracker.handle_remove("block1", EventSource::Vllm);
        assert!(!should_publish); // Should fail because block is gone
    }

    /// Deduplication across sources also applies to blocks that have a parent.
    #[test]
    fn test_deduplication_across_sources_with_parent() {
        let mut tracker = CacheStatusTracker::new();
        // vLLM stores block 1 (parent)
        tracker.handle_store(
            "vllm_parent".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        // vLLM stores block 2 (child of block 1)
        tracker.handle_store(
            "vllm_child".to_string(),
            EventSource::Vllm,
            vec![5, 6, 7, 8],
            Some("vllm_parent".to_string()),
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        // KVBM stores the SAME child block (same tokens, same parent)
        // but with a different external hash
        let should_publish = tracker.handle_store(
            "kvbm_child".to_string(), // Different external hash
            EventSource::Kvbm,
            vec![5, 6, 7, 8],                // Same tokens
            Some("vllm_parent".to_string()), // Same parent
            4,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        // Should NOT publish (deduplication)
        assert!(!should_publish);
        // Should still have 2 blocks (parent + child)
        assert_eq!(tracker.num_blocks(), 2);
    }

    /// Removing an unknown hash is a no-op that publishes nothing.
    #[test]
    fn test_remove_non_existent_block() {
        let mut tracker = CacheStatusTracker::new();
        let should_publish = tracker.handle_remove("non_existent", EventSource::Vllm);
        assert!(!should_publish);
        assert_eq!(tracker.num_blocks(), 0);
    }

    /// `compute_local_block_hash` is a pure function of the token ids.
    #[test]
    fn test_compute_local_block_hash_deterministic() {
        let tokens1 = vec![1, 2, 3, 4];
        let tokens2 = vec![1, 2, 3, 4];
        let tokens3 = vec![1, 2, 3, 5]; // Different
        let hash1 = compute_local_block_hash(&tokens1);
        let hash2 = compute_local_block_hash(&tokens2);
        let hash3 = compute_local_block_hash(&tokens3);
        // Same tokens should produce same hash
        assert_eq!(hash1, hash2);
        // Different tokens should produce different hash
        assert_ne!(hash1, hash3);
    }

    /// `compute_sequence_hash` is deterministic and sensitive to the parent hash.
    #[test]
    fn test_compute_sequence_hash_deterministic() {
        let block_hash1 = compute_local_block_hash(&[1, 2, 3, 4]);
        let block_hash2 = compute_local_block_hash(&[5, 6, 7, 8]);
        // First block: sequence hash = block hash
        let seq_hash1 = compute_sequence_hash(None, block_hash1);
        assert_eq!(seq_hash1, block_hash1);
        // Second block with parent
        let seq_hash2_v1 = compute_sequence_hash(Some(seq_hash1), block_hash2);
        let seq_hash2_v2 = compute_sequence_hash(Some(seq_hash1), block_hash2);
        // Same parent + block should produce same sequence hash
        assert_eq!(seq_hash2_v1, seq_hash2_v2);
        // Same block but different parent should produce different sequence hash
        let different_parent = compute_local_block_hash(&[9, 10, 11, 12]);
        let seq_hash2_different = compute_sequence_hash(Some(different_parent), block_hash2);
        assert_ne!(seq_hash2_v1, seq_hash2_different);
    }
}
......@@ -100,7 +100,7 @@ impl<R: LogicalResources, Metadata: BlockMetadata>
{
pub async fn new(config: KvBlockManagerConfig, logical_resources: R) -> Result<Arc<Self>> {
let model_config = config.model.clone();
let mut resources = Resources::new(config)?;
let mut resources = Resources::new(config).await?;
let block_data_factories =
logical::LogicalBlockFactories::new(&mut resources, logical_resources)?;
......@@ -220,7 +220,7 @@ impl<R: LogicalResources, Metadata: BlockMetadata>
impl<Metadata: BlockMetadata> KvBlockManagerState<locality::Local, Metadata> {
pub async fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> {
let model_config = config.model.clone();
let mut resources = Resources::new(config)?;
let mut resources = Resources::new(config).await?;
let block_data_factories = local::LocalBlockDataFactories::new(&mut resources)?;
let (mut local_block_set, disk_factory, host_factory, device_factory) =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment