Unverified Commit 038b50d2 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: add standalone KV indexer with query endpoint [DYN-2164] (#6446)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 71be641d
...@@ -140,6 +140,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -140,6 +140,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
} }
m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?; m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
m.add_function(wrap_pyfunction!(llm::kv::start_kv_block_indexer_py, m)?)?;
m.add_function(wrap_pyfunction!(lora_name_to_id, m)?)?; m.add_function(wrap_pyfunction!(lora_name_to_id, m)?)?;
m.add_function(wrap_pyfunction!(log_message, m)?)?; m.add_function(wrap_pyfunction!(log_message, m)?)?;
m.add_function(wrap_pyfunction!(register_model, m)?)?; m.add_function(wrap_pyfunction!(register_model, m)?)?;
......
...@@ -26,6 +26,26 @@ fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<Blo ...@@ -26,6 +26,26 @@ fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<Blo
depythonize(obj).map_err(to_pyerr) depythonize(obj).map_err(to_pyerr)
} }
#[pyfunction]
#[pyo3(name = "start_kv_block_indexer", signature = (endpoint, block_size, kv_router_config))]
pub fn start_kv_block_indexer_py<'p>(
py: Python<'p>,
endpoint: &Endpoint,
block_size: u32,
kv_router_config: &super::entrypoint::KvRouterConfig,
) -> PyResult<Bound<'p, PyAny>> {
let component = endpoint.inner.component().clone();
let config = kv_router_config.inner();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
llm_rs::kv_router::indexer_standalone::start_kv_block_indexer(
&component, &config, block_size,
)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
#[pyfunction] #[pyfunction]
#[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))] #[pyo3(name = "compute_block_hash_for_seq", signature = (tokens, kv_block_size, block_mm_infos=None))]
pub fn compute_block_hash_for_seq_py( pub fn compute_block_hash_for_seq_py(
......
...@@ -1044,6 +1044,12 @@ class KvRouterConfig: ...@@ -1044,6 +1044,12 @@ class KvRouterConfig:
""" """
... ...
async def start_kv_block_indexer(
    endpoint: Endpoint,
    block_size: int,
    kv_router_config: KvRouterConfig,
) -> None:
    """
    Start a standalone KV block indexer on the component that owns `endpoint`.

    The indexer subscribes to KV events for that component and serves queries on
    the component's `kv_indexer_query` endpoint.

    Args:
        endpoint: Endpoint whose component the indexer is attached to.
        block_size: KV block size used when hashing token sequences in queries.
        kv_router_config: Router configuration used to build the indexer.

    Raises:
        An exception if `kv_router_config` enables durable KV events, which the
        standalone indexer does not support, or if startup fails.
    """
    ...
async def register_model( async def register_model(
model_input: ModelInput, model_input: ModelInput,
model_type: ModelType, model_type: ModelType,
......
...@@ -12,7 +12,9 @@ For public APIs, use dynamo.runtime and dynamo.llm. ...@@ -12,7 +12,9 @@ For public APIs, use dynamo.runtime and dynamo.llm.
# Re-export from _core # Re-export from _core
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import start_kv_block_indexer as start_kv_block_indexer
__all__ = [ __all__ = [
"ModelDeploymentCard", "ModelDeploymentCard",
"start_kv_block_indexer",
] ]
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
# Type stubs - re-export from _core # Type stubs - re-export from _core
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import start_kv_block_indexer as start_kv_block_indexer
__all__ = [ __all__ = [
"ModelDeploymentCard", "ModelDeploymentCard",
"start_kv_block_indexer",
] ]
...@@ -29,6 +29,7 @@ pub use dynamo_kv_router::indexer; ...@@ -29,6 +29,7 @@ pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols; pub use dynamo_kv_router::protocols;
pub mod config; pub mod config;
pub mod indexer_standalone;
pub mod metrics; pub mod metrics;
pub mod prefill_router; pub mod prefill_router;
pub mod publisher; pub mod publisher;
...@@ -41,6 +42,7 @@ pub mod subscriber; ...@@ -41,6 +42,7 @@ pub mod subscriber;
pub mod worker_query; pub mod worker_query;
pub use config::{KvRouterConfig, RouterConfigOverride}; pub use config::{KvRouterConfig, RouterConfigOverride};
pub use indexer_standalone::start_kv_block_indexer;
pub use prefill_router::PrefillRouter; pub use prefill_router::PrefillRouter;
pub use push_router::{DirectRoutingRouter, KvPushRouter}; pub use push_router::{DirectRoutingRouter, KvPushRouter};
...@@ -78,6 +80,9 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; ...@@ -78,6 +80,9 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub const RADIX_STATE_BUCKET: &str = "radix-bucket"; pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state"; pub const RADIX_STATE_FILE: &str = "radix-state";
// for standalone indexer query: name of the endpoint on which the standalone
// indexer serves its query requests
pub const KV_INDEXER_QUERY_ENDPOINT: &str = "kv_indexer_query";
// for worker-local kvindexer query // for worker-local kvindexer query
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use serde::{Deserialize, Serialize};
use dynamo_runtime::{
component::Component,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
network::Ingress,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
stream,
};
use crate::kv_router::{
Indexer, KV_INDEXER_QUERY_ENDPOINT, KvRouterConfig,
protocols::{
BlockExtraInfo, LocalBlockHash, OverlapScores, RouterEvent, compute_block_hash_for_seq,
},
subscriber,
};
/// Requests accepted by the standalone indexer's query endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub enum IndexerQueryRequest {
    /// Look up overlap scores for block hashes the caller has already computed.
    FindMatchesHashed {
        /// Hashes passed to the indexer's `find_matches` as-is.
        block_hashes: Vec<LocalBlockHash>,
    },
    /// Look up overlap scores for a raw token sequence; the engine hashes the
    /// tokens into blocks with its own block size before querying the indexer.
    FindMatchesTokens {
        /// Token ids of the sequence to match.
        tokens: Vec<u32>,
        /// Optional per-block extra info forwarded to `compute_block_hash_for_seq`.
        block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
    },
    /// Return the events reported by the indexer's `dump_events`.
    DumpTree,
}
/// Responses produced by the standalone indexer's query endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub enum IndexerQueryResponse {
    /// Overlap scores returned for a `FindMatches*` request.
    Matches(OverlapScores),
    /// Events returned for a `DumpTree` request.
    TreeDump(Vec<RouterEvent>),
    /// Text describing a failure (an indexer error formatted with `{:?}`, or the
    /// display text of an error passed to `MaybeError::from_err`).
    Error(String),
}
impl MaybeError for IndexerQueryResponse {
fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
IndexerQueryResponse::Error(err.to_string())
}
fn err(&self) -> Option<anyhow::Error> {
match self {
IndexerQueryResponse::Error(msg) => Some(anyhow::Error::msg(msg.clone())),
_ => None,
}
}
}
/// Engine backing the standalone indexer's query endpoint.
struct IndexerQueryEngine {
    /// Indexer that answers `find_matches` and `dump_events` calls.
    indexer: Indexer,
    /// Block size used to hash raw tokens for `FindMatchesTokens` requests.
    block_size: u32,
}
#[async_trait]
impl
    AsyncEngine<
        SingleIn<IndexerQueryRequest>,
        ManyOut<Annotated<IndexerQueryResponse>>,
        anyhow::Error,
    > for IndexerQueryEngine
{
    /// Answers one query with a single-item response stream.
    ///
    /// `DumpTree` returns the indexer's event dump; both `FindMatches*` variants
    /// return overlap scores (token requests are hashed with `self.block_size`
    /// first). Indexer failures are reported in-band as
    /// `IndexerQueryResponse::Error` rather than as an `Err` from this method.
    async fn generate(
        &self,
        request: SingleIn<IndexerQueryRequest>,
    ) -> Result<ManyOut<Annotated<IndexerQueryResponse>>> {
        let (query, ctx) = request.into_parts();

        // `None` selects the tree dump; `Some(hashes)` selects an overlap lookup.
        let lookup_hashes = match query {
            IndexerQueryRequest::DumpTree => None,
            IndexerQueryRequest::FindMatchesHashed { block_hashes } => Some(block_hashes),
            IndexerQueryRequest::FindMatchesTokens {
                tokens,
                block_mm_infos,
            } => Some(compute_block_hash_for_seq(
                &tokens,
                self.block_size,
                block_mm_infos.as_deref(),
            )),
        };

        let reply = match lookup_hashes {
            None => match self.indexer.dump_events().await {
                Ok(events) => IndexerQueryResponse::TreeDump(events),
                Err(e) => IndexerQueryResponse::Error(format!("{e:?}")),
            },
            Some(hashes) => match self.indexer.find_matches(hashes).await {
                Ok(scores) => IndexerQueryResponse::Matches(scores),
                Err(e) => IndexerQueryResponse::Error(format!("{e:?}")),
            },
        };

        let items = vec![Annotated::from_data(reply)];
        Ok(ResponseStream::new(
            Box::pin(stream::iter(items)),
            ctx.context(),
        ))
    }
}
/// Registers the query endpoint on `component` and serves it from a spawned task.
///
/// Returns as soon as the serving task has been spawned. If building the ingress
/// handler fails, that error is returned; a later failure of the endpoint itself
/// is only logged.
async fn start_indexer_query_endpoint(
    component: Component,
    indexer: Indexer,
    block_size: u32,
) -> Result<()> {
    let query_engine = std::sync::Arc::new(IndexerQueryEngine {
        indexer,
        block_size,
    });
    let handler = Ingress::for_engine(query_engine)?;

    let serve = component
        .endpoint(KV_INDEXER_QUERY_ENDPOINT)
        .endpoint_builder()
        .handler(handler)
        .graceful_shutdown(true)
        .start();

    // The task is detached: its handle is not kept, so errors surface via the log.
    tokio::spawn(async move {
        match serve.await {
            Ok(_) => {}
            Err(e) => {
                tracing::error!("Indexer query endpoint failed: {e:?}");
            }
        }
    });

    Ok(())
}
/// Starts a standalone KV block indexer for `component`.
///
/// Creates the indexer, starts the event subscriber that feeds it, and exposes
/// the query endpoint named by `KV_INDEXER_QUERY_ENDPOINT`.
///
/// # Errors
///
/// Fails immediately when `kv_router_config.durable_kv_events` is set, and
/// propagates any error from starting the subscriber or the query endpoint.
///
/// # Returns
///
/// The created `Indexer`; clones of it are held by the subscriber and the
/// query engine.
pub async fn start_kv_block_indexer(
    component: &Component,
    kv_router_config: &KvRouterConfig,
    block_size: u32,
) -> Result<Indexer> {
    anyhow::ensure!(
        !kv_router_config.durable_kv_events,
        "standalone indexer does not support durable_kv_events (JetStream): \
         consumer ID collisions, orphan cleanup conflicts, and snapshot/purge races \
         make it incompatible with an independent indexer"
    );

    let kv_indexer = Indexer::new(component, kv_router_config, block_size);

    // Feed the indexer first, then make it queryable.
    subscriber::start_subscriber(component.clone(), kv_router_config, kv_indexer.clone())
        .await?;
    start_indexer_query_endpoint(component.clone(), kv_indexer.clone(), block_size).await?;

    tracing::info!(
        "Standalone KV indexer started with query endpoint '{KV_INDEXER_QUERY_ENDPOINT}'"
    );

    Ok(kv_indexer)
}
...@@ -62,12 +62,21 @@ impl Client { ...@@ -62,12 +62,21 @@ impl Client {
); );
let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?; let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
let (avail_tx, avail_rx) = tokio::sync::watch::channel(vec![]); // Seed instance_avail from the current instance_source snapshot so that
// callers who proceed immediately after wait_for_instances (which reads
// instance_source directly) will also find instances in instance_avail
// (which is read by the routing methods like random/round_robin).
let initial_ids: Vec<u64> = instance_source
.borrow()
.iter()
.map(|instance| instance.id())
.collect();
let (avail_tx, avail_rx) = tokio::sync::watch::channel(initial_ids.clone());
let client = Client { let client = Client {
endpoint: endpoint.clone(), endpoint: endpoint.clone(),
instance_source: instance_source.clone(), instance_source: instance_source.clone(),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_avail: Arc::new(ArcSwap::from(Arc::new(initial_ids.clone()))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_free: Arc::new(ArcSwap::from(Arc::new(initial_ids))),
instance_avail_tx: Arc::new(avail_tx), instance_avail_tx: Arc::new(avail_tx),
instance_avail_rx: avail_rx, instance_avail_rx: avail_rx,
reconcile_interval, reconcile_interval,
......
...@@ -38,6 +38,14 @@ static ACTUAL_TCP_RPC_PORT: OnceLock<u16> = OnceLock::new(); ...@@ -38,6 +38,14 @@ static ACTUAL_TCP_RPC_PORT: OnceLock<u16> = OnceLock::new();
static GLOBAL_TCP_SERVER: tokio::sync::OnceCell<Arc<SharedTcpServer>> = static GLOBAL_TCP_SERVER: tokio::sync::OnceCell<Arc<SharedTcpServer>> =
tokio::sync::OnceCell::const_new(); tokio::sync::OnceCell::const_new();
/// Process-wide cancellation token for the global TCP server.
///
/// This token is independent of any individual runtime's cancellation token so that
/// component Drop impls (e.g. KvRouter::drop → cancel) don't kill the shared accept
/// loop while the OnceCell still hands out the (now-dead) server to later runtimes.
///
/// The token is created lazily on first access; a clone of it is what the shared
/// TCP server is constructed with.
static GLOBAL_TCP_SERVER_TOKEN: std::sync::LazyLock<CancellationToken> =
    std::sync::LazyLock::new(CancellationToken::new);
/// Get the actual TCP RPC port that the server is listening on. /// Get the actual TCP RPC port that the server is listening on.
pub fn get_actual_tcp_rpc_port() -> anyhow::Result<u16> { pub fn get_actual_tcp_rpc_port() -> anyhow::Result<u16> {
ACTUAL_TCP_RPC_PORT.get().copied().ok_or_else(|| { ACTUAL_TCP_RPC_PORT.get().copied().ok_or_else(|| {
...@@ -328,7 +336,7 @@ impl NetworkManager { ...@@ -328,7 +336,7 @@ impl NetworkManager {
"Creating TCP request plane server" "Creating TCP request plane server"
); );
let server = SharedTcpServer::new(bind_addr, self.cancellation_token.clone()); let server = SharedTcpServer::new(bind_addr, GLOBAL_TCP_SERVER_TOKEN.clone());
// Bind and start server, getting the actual bound address // Bind and start server, getting the actual bound address
let actual_addr = server.clone().bind_and_start().await?; let actual_addr = server.clone().bind_and_start().await?;
......
...@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Any, Optional
import aiohttp import aiohttp
import nats import nats
from dynamo._core import DistributedRuntime, KvRouter, KvRouterConfig from dynamo._internal import start_kv_block_indexer
from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.runtime import DistributedRuntime
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -115,6 +117,46 @@ def generate_random_suffix() -> str: ...@@ -115,6 +117,46 @@ def generate_random_suffix() -> str:
return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311 return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311
def assert_event_dumps_equal(
    expected: list[dict],
    actual: list[dict],
    expected_label: str,
    actual_label: str,
) -> None:
    """Assert two sorted event dump lists are equal, ignoring event_id fields.

    The inputs are never modified: event_id is dropped from a private copy of
    each item's nested "event" mapping, so callers can reuse the same dump for
    several comparisons and failure messages show the original items.

    Args:
        expected: Reference event dump, already sorted.
        actual: Event dump to check, sorted the same way.
        expected_label: Name used for ``expected`` in failure messages.
        actual_label: Name used for ``actual`` in failure messages.

    Raises:
        AssertionError: If the lists differ in length, or if any pair of items
            differs once event_id is ignored.
    """

    def _without_event_id(item: dict) -> dict:
        # dict.copy() is shallow, so the nested "event" mapping must be rebuilt
        # rather than edited in place; otherwise the caller's data is mutated.
        stripped = item.copy()
        event = stripped.get("event")
        if isinstance(event, dict) and "event_id" in event:
            stripped["event"] = {
                key: value for key, value in event.items() if key != "event_id"
            }
        return stripped

    if len(expected) != len(actual):
        raise AssertionError(
            f"{expected_label} has {len(expected)} events, "
            f"{actual_label} has {len(actual)} events"
        )

    # Tuples instead of a dict keyed by the labels: the labels are caller-supplied
    # and could collide with each other or with a fixed key such as "index".
    differences = [
        (index, exp_item, act_item)
        for index, (exp_item, act_item) in enumerate(zip(expected, actual))
        if _without_event_id(exp_item) != _without_event_id(act_item)
    ]
    if not differences:
        return

    parts = [
        f"{expected_label} and {actual_label} differ. "
        f"Found {len(differences)} differences:\n"
    ]
    for index, exp_item, act_item in differences:
        parts.append(f"\nDifference at index {index}:\n")
        parts.append(f"{expected_label}: {json.dumps(exp_item, indent=2)}\n")
        parts.append(f"{actual_label}: {json.dumps(act_item, indent=2)}\n")
        parts.append("-" * 80 + "\n")
    # Raised explicitly so the check still fires when asserts are disabled (-O).
    raise AssertionError("".join(parts))
def verify_response_worker_ids( def verify_response_worker_ids(
response_worker_ids: list[dict[str, Optional[int]]], response_worker_ids: list[dict[str, Optional[int]]],
key: str, key: str,
...@@ -1568,57 +1610,49 @@ def _test_router_indexers_sync( ...@@ -1568,57 +1610,49 @@ def _test_router_indexers_sync(
sorted_state1 = sorted(state1, key=sort_key) sorted_state1 = sorted(state1, key=sort_key)
sorted_state2 = sorted(state2, key=sort_key) sorted_state2 = sorted(state2, key=sort_key)
# Verify they are equal
logger.info(f"Router 1 has {len(sorted_state1)} events") logger.info(f"Router 1 has {len(sorted_state1)} events")
logger.info(f"Router 2 has {len(sorted_state2)} events") logger.info(f"Router 2 has {len(sorted_state2)} events")
# Compare states one by one and only show differences assert_event_dumps_equal(sorted_state1, sorted_state2, "Router 1", "Router 2")
if len(sorted_state1) != len(sorted_state2): logger.info("Successfully verified that both router states are equal")
logger.error(
f"Router 1 has {len(sorted_state1)} events, Router 2 has {len(sorted_state2)} events" # Verify standalone indexer builds the same tree (only for non-durable/NATS Core)
if not durable_kv_events:
logger.info("Starting standalone indexer and verifying tree state")
runtime3 = get_runtime(store_backend, request_plane)
endpoint3 = runtime3.endpoint(
f"{engine_workers.namespace}.{engine_workers.component_name}.generate"
) )
assert False, "Router states have different numbers of events" await start_kv_block_indexer(endpoint3, block_size, kv_router_config)
differences = [] # Wait for the standalone indexer to sync events from workers
for i, (state1_item, state2_item) in enumerate( await asyncio.sleep(3)
zip(sorted_state1, sorted_state2)
): # Query the standalone indexer's tree via kv_indexer_query endpoint
# Create copies without event_id for comparison # Note: reuse runtime3 to keep the standalone indexer's component alive
item1_compare = state1_item.copy() query_endpoint = runtime3.endpoint(
item2_compare = state2_item.copy() f"{engine_workers.namespace}.{engine_workers.component_name}.kv_indexer_query"
# Remove event_id from the nested event structure
if "event" in item1_compare and "event_id" in item1_compare["event"]:
del item1_compare["event"]["event_id"]
if "event" in item2_compare and "event_id" in item2_compare["event"]:
del item2_compare["event"]["event_id"]
if item1_compare != item2_compare:
differences.append(
{
"index": i,
"router1_state": state1_item,
"router2_state": state2_item,
}
)
# If there are differences, format them for easier debugging
if differences:
error_msg = (
f"Router states are not equal. Found {len(differences)} differences:\n"
) )
for diff in differences: query_client = await query_endpoint.client()
error_msg += f"\nDifference at index {diff['index']}:\n" await query_client.wait_for_instances()
error_msg += ( stream = await query_client.generate("DumpTree", annotated=False)
f"Router 1: {json.dumps(diff['router1_state'], indent=2)}\n" response = await stream.__anext__()
) standalone_state = response["TreeDump"]
error_msg += (
f"Router 2: {json.dumps(diff['router2_state'], indent=2)}\n"
)
error_msg += "-" * 80 + "\n"
assert False, error_msg sorted_standalone = sorted(standalone_state, key=sort_key)
logger.info("Successfully verified that both router states are equal") logger.info(f"Standalone indexer has {len(sorted_standalone)} events")
assert_event_dumps_equal(
sorted_state1, sorted_standalone, "Router 1", "Standalone"
)
logger.info(
"Successfully verified standalone indexer state matches router states"
)
else:
logger.info(
"Skipping standalone indexer verification (not supported with durable_kv_events)"
)
# Verify NATS consumers are created (while routers are still alive) # Verify NATS consumers are created (while routers are still alive)
# Skip this for NATS interruption test since it uses local indexer (NATS Core, not JetStream) # Skip this for NATS interruption test since it uses local indexer (NATS Core, not JetStream)
......
...@@ -537,7 +537,7 @@ def test_kv_router_bindings( ...@@ -537,7 +537,7 @@ def test_kv_router_bindings(
], ],
indirect=["request_plane", "durable_kv_events"], indirect=["request_plane", "durable_kv_events"],
) )
@pytest.mark.timeout(180) # bumped for xdist contention (was 90s; up to 33s under load) @pytest.mark.timeout(180)
def test_indexers_sync( def test_indexers_sync(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment